// Copyright (c) 2010-2025, Lawrence Livermore National Security, LLC. Produced
// at the Lawrence Livermore National Laboratory. All Rights reserved. See files
// LICENSE and NOTICE for details. LLNL-CODE-806117.
//
// This file is part of the MFEM library. For more information and source code
// availability visit https://mfem.org.
//
// MFEM is free software; you can redistribute it and/or modify it under the
// terms of the BSD-3 license. We welcome feedback and contributions, see file
// CONTRIBUTING.md for details.

#include "../config/config.hpp"

#ifdef MFEM_USE_MPI

#include "mesh_headers.hpp"
#include "pncmesh.hpp"
#include "../general/binaryio.hpp"
#include "../general/communication.hpp"

#include <numeric> // std::accumulate
#include <map>
#include <climits> // INT_MIN, INT_MAX
#include <array>

namespace mfem
{

using namespace bin_io;

ParNCMesh::ParNCMesh(MPI_Comm comm, const NCMesh &ncmesh,
                     const int *partitioning)
   : NCMesh(ncmesh)
{
   MyComm = comm;
   MPI_Comm_size(MyComm, &NRanks);
   MPI_Comm_rank(MyComm, &MyRank);

   // Give every leaf element a rank: either the one from the user-supplied
   // 'partitioning' array (indexed by leaf number), or the default obtained by
   // cutting the leaf sequence into 'NRanks' parts (InitialPartition).
   const int num_leaves = leaf_elements.Size();
   for (int leaf = 0; leaf < num_leaves; leaf++)
   {
      Element &elem = elements[leaf_elements[leaf]];
      if (partitioning) { elem.rank = partitioning[leaf]; }
      else { elem.rank = InitialPartition(leaf); }
   }

   Update();

   // Every rank still holds the complete set of leaves at this point; branches
   // containing only other ranks' leaves can be pruned later (see Prune()).
}

ParNCMesh::ParNCMesh(MPI_Comm comm, std::istream &input, int version,
                     int &curved, int &is_nc)
   : NCMesh(input, version, curved, is_nc)
{
   MFEM_VERIFY(version != 11, "Nonconforming mesh format \"MFEM NC mesh v1.1\""
               " is supported only in serial.");

   MyComm = comm;
   MPI_Comm_size(MyComm, &NRanks);

   // Rank of this process in 'comm'. It is compared with MyRank, which was
   // set while reading the mesh stream in the NCMesh constructor.
   int my_rank;
   MPI_Comm_rank(MyComm, &my_rank);

   // Highest rank referenced by any leaf element stored in the file.
   int max_rank = 0;
   for (int leaf : leaf_elements)
   {
      max_rank = std::max(elements[leaf].rank, max_rank);
   }

   MFEM_VERIFY((my_rank == MyRank) && (max_rank < NRanks),
               "Parallel mesh file doesn't seem to match current MPI setup. "
               "Loading a parallel NC mesh with a non-matching communicator "
               "size is not supported.");

   // The mesh counts as isotropic only if it is isotropic on every rank.
   bool local_iso = Iso;
   MPI_Allreduce(&local_iso, &Iso, 1, MFEM_MPI_CXX_BOOL, MPI_LAND, MyComm);

   Update();
}

/** Copy constructor. Copies the NCMesh part (via NCMesh's copy constructor)
    together with the communicator and the number of ranks. All derived
    parallel data (groups, ownership, shared lists, layers) is reset by
    Update() and rebuilt on demand. MyRank is not listed below; presumably it
    is copied by NCMesh(other) -- TODO confirm. */
ParNCMesh::ParNCMesh(const ParNCMesh &other)
// copy primary data only
   : NCMesh(other)
   , MyComm(other.MyComm)
   , NRanks(other.NRanks)
{
   Update(); // mark all secondary stuff for recalculation
}

/// Destructor; delegates the cleanup of auxiliary data to ClearAuxPM().
ParNCMesh::~ParNCMesh()
{
   ClearAuxPM();
}

void ParNCMesh::Update()
{
   NCMesh::Update();

   groups.clear();
   group_id.clear();

   CommGroup self;
   self.push_back(MyRank);
   groups.push_back(self);
   group_id[self] = 0;

   for (int i = 0; i < 3; i++)
   {
      entity_owner[i].DeleteAll();
      entity_pmat_group[i].DeleteAll();
      entity_index_rank[i].DeleteAll();
      entity_conf_group[i].DeleteAll();
      entity_elem_local[i].DeleteAll();
   }

   shared_vertices.Clear();
   shared_edges.Clear();
   shared_faces.Clear();

   element_type.SetSize(0);
   ghost_layer.SetSize(0);
   boundary_layer.SetSize(0);
}

void ParNCMesh::ElementSharesFace(int elem, int local, int face)
{
   // Analogous to ElementSharesEdge.

   Element &el = elements[elem];
   int f_index = faces[face].index;

   int &owner = tmp_owner[f_index];
   owner = std::min(owner, el.rank);

   char &flag = tmp_shared_flag[f_index];
   flag |= (el.rank == MyRank) ? 0x1 : 0x2;

   entity_index_rank[2].Append(Connection(f_index, el.rank));

   // derive globally consistent face ID from the global element sequence
   int &el_loc = entity_elem_local[2][f_index];
   if (el_loc < 0 || leaf_sfc_index[el.index] < leaf_sfc_index[(el_loc >> 4)])
   {
      el_loc = (el.index << 4) | local;
   }
}

/** Parallel extension of NCMesh::BuildFaceList().

    Besides building face_list, it determines for every face (local + ghost):
    - the owner: the lowest rank among the elements touching it,
    - whether it is shared (touched by this rank and another rank),
    - a globally consistent (element, local face) representative,
    - the simple conforming ("cut-mesh") communication group.
    It then fills shared_faces and computes the orientations of shared faces.
    In 1D/2D, or with no leaf elements, it only clears the lists. */
void ParNCMesh::BuildFaceList()
{
   if (HaveTets()) { GetEdgeList(); } // needed by TraverseTetEdge()

   // This is an extension of NCMesh::BuildFaceList() which also determines
   // face ownership and prepares face processor groups.

   face_list.Clear();
   shared_faces.Clear();
   boundary_faces.SetSize(0);

   if (Dim < 3 || !leaf_elements.Size()) { return; }

   // faces are indexed over both local and ghost faces
   int nfaces = NFaces + NGhostFaces;

   // Temporaries written by ElementSharesFace(), which NCMesh::BuildFaceList()
   // is expected to call for each visited (leaf element, face) pair.
   tmp_owner.SetSize(nfaces);
   tmp_owner = INT_MAX;

   tmp_shared_flag.SetSize(nfaces);
   tmp_shared_flag = 0;

   // Capacity hint (6 faces per leaf element, times 3/2), then empty the array
   // while keeping the allocation.
   entity_index_rank[2].SetSize(6*leaf_elements.Size() * 3/2);
   entity_index_rank[2].SetSize(0);

   // -1 = no representative (element, local face) recorded yet
   entity_elem_local[2].SetSize(nfaces);
   entity_elem_local[2] = -1;

   NCMesh::BuildFaceList();

   InitOwners(nfaces, entity_owner[2]);
   MakeSharedList(face_list, shared_faces);

   tmp_owner.DeleteAll();
   tmp_shared_flag.DeleteAll();

   // create simple conforming (cut-mesh) groups now
   CreateGroups(NFaces, entity_index_rank[2], entity_conf_group[2]);
   // NOTE: entity_index_rank[2] is not deleted until CalculatePMatrixGroups

   CalcFaceOrientations();
}

void ParNCMesh::ElementSharesEdge(int elem, int local, int enode)
{
   // Called by NCMesh::BuildEdgeList when an edge is visited in a leaf element.
   // This allows us to determine edge ownership and whether it is shared
   // without duplicating all the HashTable lookups in NCMesh::BuildEdgeList().

   Element &el= elements[elem];
   int e_index = nodes[enode].edge_index;

   int &owner = tmp_owner[e_index];
   owner = std::min(owner, el.rank);

   char &flag = tmp_shared_flag[e_index];
   flag |= (el.rank == MyRank) ? 0x1 : 0x2;

   entity_index_rank[1].Append(Connection(e_index, el.rank));

   // derive globally consistent edge ID from the global element sequence
   int &el_loc = entity_elem_local[1][e_index];
   if (el_loc < 0 || leaf_sfc_index[el.index] < leaf_sfc_index[(el_loc >> 4)])
   {
      el_loc = (el.index << 4) | local;
   }
}

/** Parallel extension of NCMesh::BuildEdgeList().

    Besides building edge_list, it determines for every edge (local + ghost)
    its owner (lowest touching rank), whether it is shared, a globally
    consistent (element, local edge) representative, and the conforming
    communication group. It then fills shared_edges. In 2D, boundary_faces is
    reset here because edges play the role of faces. In 1D, or with no leaf
    elements, it only clears the lists. */
void ParNCMesh::BuildEdgeList()
{
   // This is an extension of NCMesh::BuildEdgeList() which also determines
   // edge ownership and prepares edge processor groups.

   edge_list.Clear();
   shared_edges.Clear();
   if (Dim < 3) { boundary_faces.SetSize(0); }

   if (Dim < 2 || !leaf_elements.Size()) { return; }

   // edges are indexed over both local and ghost edges
   int nedges = NEdges + NGhostEdges;

   // temporaries written by ElementSharesEdge() during NCMesh::BuildEdgeList()
   tmp_owner.SetSize(nedges);
   tmp_owner = INT_MAX;

   tmp_shared_flag.SetSize(nedges);
   tmp_shared_flag = 0;

   // Capacity hint (12 edges per leaf element, times 3/2), then empty the
   // array while keeping the allocation.
   entity_index_rank[1].SetSize(12*leaf_elements.Size() * 3/2);
   entity_index_rank[1].SetSize(0);

   // -1 = no representative (element, local edge) recorded yet
   entity_elem_local[1].SetSize(nedges);
   entity_elem_local[1] = -1;

   NCMesh::BuildEdgeList();

   InitOwners(nedges, entity_owner[1]);
   MakeSharedList(edge_list, shared_edges);

   tmp_owner.DeleteAll();
   tmp_shared_flag.DeleteAll();

   // create simple conforming (cut-mesh) groups now
   CreateGroups(NEdges, entity_index_rank[1], entity_conf_group[1]);
   // NOTE: entity_index_rank[1] is not deleted until CalculatePMatrixGroups
}

void ParNCMesh::FindEdgesOfGhostFace(int face, Array<int> & edges)
{
   // Look up the face's MeshId in the face list. Faces without one yield an
   // empty edge list; otherwise the face's edge indices are returned.
   NCList::MeshIdAndType midt = GetFaceList().GetMeshIdAndType(face);
   if (midt.id)
   {
      int V[4], E[4], Eo[4];
      const int nfv = GetFaceVerticesEdges(*midt.id, V, E, Eo);
      MFEM_ASSERT(nfv == 4, "");

      edges.SetSize(nfv);
      for (int k = 0; k < nfv; k++) { edges[k] = E[k]; }
   }
   else
   {
      edges.SetSize(0);
   }
}

void ParNCMesh::FindEdgesOfGhostElement(int elem, Array<int> & edges)
{
   NCMesh::Element &el = elements[elem]; // ghost element
   MFEM_ASSERT(el.rank != MyRank, "");

   MFEM_ASSERT(!el.ref_type, "not a leaf element.");

   GeomInfo& gi = GI[el.Geom()];
   edges.SetSize(gi.ne);

   for (int j = 0; j < gi.ne; j++)
   {
      // get node for this edge
      const int* ev = gi.edges[j];
      int node[2] = { el.node[ev[0]], el.node[ev[1]] };

      int enode = nodes.FindId(node[0], node[1]);
      MFEM_ASSERT(enode >= 0, "edge node not found!");

      Node &nd = nodes[enode];
      MFEM_ASSERT(nd.HasEdge(), "edge not found!");

      edges[j] = nd.edge_index;
   }
}

void ParNCMesh::FindFacesOfGhostElement(int elem, Array<int> & faces)
{
   // Return the face indices of a (leaf) ghost element, in the element's
   // local face order.
   NCMesh::Element &el = elements[elem]; // ghost element
   MFEM_ASSERT(el.rank != MyRank, "");
   MFEM_ASSERT(!el.ref_type, "not a leaf element.");

   const int nf = GI[el.Geom()].nf;
   faces.SetSize(nf);
   for (int lf = 0; lf < nf; lf++)
   {
      faces[lf] = GetFace(el, lf)->index;
   }
}

void ParNCMesh::ElementSharesVertex(int elem, int local, int vnode)
{
   // Analogous to ElementSharesEdge.

   Element &el = elements[elem];
   int v_index = nodes[vnode].vert_index;

   int &owner = tmp_owner[v_index];
   owner = std::min(owner, el.rank);

   char &flag = tmp_shared_flag[v_index];
   flag |= (el.rank == MyRank) ? 0x1 : 0x2;

   entity_index_rank[0].Append(Connection(v_index, el.rank));

   // derive globally consistent vertex ID from the global element sequence
   int &el_loc = entity_elem_local[0][v_index];
   if (el_loc < 0 || leaf_sfc_index[el.index] < leaf_sfc_index[(el_loc >> 4)])
   {
      el_loc = (el.index << 4) | local;
   }
}

/** Parallel extension of NCMesh::BuildVertexList().

    Besides building vertex_list, it determines for every vertex (local +
    ghost) its owner (lowest touching rank), whether it is shared, a globally
    consistent (element, local vertex) representative, and the conforming
    communication group. It then fills shared_vertices. */
void ParNCMesh::BuildVertexList()
{
   // This is an extension of NCMesh::BuildVertexList() which also determines
   // vertex ownership and creates vertex processor groups.

   // vertices are indexed over both local and ghost vertices
   int nvertices = NVertices + NGhostVertices;

   // temporaries written by ElementSharesVertex() (presumably invoked from
   // NCMesh::BuildVertexList() for each visited vertex -- TODO confirm)
   tmp_owner.SetSize(nvertices);
   tmp_owner = INT_MAX;

   tmp_shared_flag.SetSize(nvertices);
   tmp_shared_flag = 0;

   // Capacity hint (8 vertices per leaf element), then empty the array while
   // keeping the allocation.
   entity_index_rank[0].SetSize(8*leaf_elements.Size());
   entity_index_rank[0].SetSize(0);

   // -1 = no representative (element, local vertex) recorded yet
   entity_elem_local[0].SetSize(nvertices);
   entity_elem_local[0] = -1;

   NCMesh::BuildVertexList();

   InitOwners(nvertices, entity_owner[0]);
   MakeSharedList(vertex_list, shared_vertices);

   tmp_owner.DeleteAll();
   tmp_shared_flag.DeleteAll();

   // create simple conforming (cut-mesh) groups now
   CreateGroups(NVertices, entity_index_rank[0], entity_conf_group[0]);
   // NOTE: entity_index_rank[0] is not deleted until CalculatePMatrixGroups
}

void ParNCMesh::InitOwners(int num, Array<GroupId> &entity_owner_)
{
   // Turn the lowest-rank values gathered in tmp_owner into singleton group
   // ids. Entities never visited (still INT_MAX) are assigned group 0.
   entity_owner_.SetSize(num);
   for (int idx = 0; idx < num; idx++)
   {
      const int owner_rank = tmp_owner[idx];
      if (owner_rank == INT_MAX) { entity_owner_[idx] = 0; }
      else { entity_owner_[idx] = GetSingletonGroup(owner_rank); }
   }
}

/** Extract the shared entities of @a list into @a shared.

    Uses tmp_shared_flag (bit 0x1 = touched by this rank, bit 0x2 = touched by
    another rank), so it must be called while that array is populated (see the
    Build*List() routines). Masters and their slaves first exchange flags, so
    that a master touching a remote slave (and vice versa) is also considered
    shared. An entity is then "shared" when its flag is exactly 0x3. Slaves
    with a negative index (prism edge-face constraints) are never copied to
    @a shared. */
void ParNCMesh::MakeSharedList(const NCList &list, NCList &shared)
{
   MFEM_VERIFY(tmp_shared_flag.Size(), "wrong code path");

   // combine flags of masters and slaves
   for (int i = 0; i < list.masters.Size(); i++)
   {
      const Master &master = list.masters[i];
      char &master_flag = tmp_shared_flag[master.index];
      // slaves receive the master's flag as it was before merging any slave
      char master_old_flag = master_flag;

      for (int j = master.slaves_begin; j < master.slaves_end; j++)
      {
         int si = list.slaves[j].index;
         if (si >= 0)
         {
            char &slave_flag = tmp_shared_flag[si];
            master_flag |= slave_flag;
            slave_flag |= master_old_flag;
         }
         else // special case: prism edge-face constraint
         {
            // the slave is an edge, stored as index -1-si
            // NOTE(review): entity_owner[] holds GroupIds (0 == this rank's
            // singleton group), not ranks, so comparing with MyRank looks
            // inconsistent; "!= 0" or "groups[...][0] != MyRank" may have been
            // intended. Verify before changing, since it affects which faces
            // are flagged as shared.
            if (entity_owner[1][-1-si] != MyRank)
            {
               master_flag |= 0x2;
            }
         }
      }
   }

   shared.Clear();

   // keep only entities seen both from this rank and from another rank
   for (int i = 0; i < list.conforming.Size(); i++)
   {
      if (tmp_shared_flag[list.conforming[i].index] == 0x3)
      {
         shared.conforming.Append(list.conforming[i]);
      }
   }
   for (int i = 0; i < list.masters.Size(); i++)
   {
      if (tmp_shared_flag[list.masters[i].index] == 0x3)
      {
         shared.masters.Append(list.masters[i]);
      }
   }
   for (int i = 0; i < list.slaves.Size(); i++)
   {
      int si = list.slaves[i].index;
      if (si >= 0 && tmp_shared_flag[si] == 0x3)
      {
         shared.slaves.Append(list.slaves[i]);
      }
   }
}

/** Strict weak ordering of communication groups: groups with fewer ranks come
    first, and groups of equal size are compared lexicographically (the first
    differing rank decides). */
bool operator<(const ParNCMesh::CommGroup &lhs, const ParNCMesh::CommGroup &rhs)
{
   if (lhs.size() != rhs.size())
   {
      return lhs.size() < rhs.size();
   }
   for (unsigned i = 0; i < lhs.size(); i++)
   {
      // Stop at the first difference. Scanning past lhs[i] > rhs[i] would make
      // e.g. both {1,5} < {2,3} and {2,3} < {1,5} true, which breaks the
      // ordering required by sorted containers such as std::map.
      if (lhs[i] != rhs[i]) { return lhs[i] < rhs[i]; }
   }
   return false; // equal groups
}

#ifdef MFEM_DEBUG
/// Debug helper: true if the ranks in @a group are strictly increasing.
static bool group_sorted(const ParNCMesh::CommGroup &group)
{
   for (std::size_t k = 1; k < group.size(); k++)
   {
      const bool increasing = (group[k-1] < group[k]);
      if (!increasing) { return false; }
   }
   return true;
}
#endif

ParNCMesh::GroupId ParNCMesh::GetGroupId(const CommGroup &group)
{
   // The singleton {MyRank} is always group 0.
   const bool is_self = (group.size() == 1 && group[0] == MyRank);
   if (is_self) { return 0; }

   MFEM_ASSERT(group_sorted(group), "invalid group");

   // operator[] value-initializes the id of an unseen group to 0. Since 0 is
   // reserved for the self group, a zero here means "new group": register it
   // and give it the next free id.
   GroupId &id = group_id[group];
   if (id == 0)
   {
      id = groups.size();
      groups.push_back(group);
   }
   return id;
}

ParNCMesh::GroupId ParNCMesh::GetSingletonGroup(int rank)
{
   MFEM_ASSERT(rank != INT_MAX, "invalid rank");

   // Function-local static buffer, reused across calls (and across ParNCMesh
   // instances), so it is not reentrant or thread-safe.
   static std::vector<int> group;
   group.assign(1, rank);
   return GetGroupId(group);
}

bool ParNCMesh::GroupContains(GroupId id, int rank) const
{
   // Linear search over the group's ranks.
   // TODO: would std::lower_bound() pay off here? Groups are usually small.
   for (const int member : groups[id])
   {
      if (member == rank) { return true; }
   }
   return false;
}

void ParNCMesh::CreateGroups(int nentities, Array<Connection> &index_rank,
                             Array<GroupId> &entity_group)
{
   // Sort the (entity index, rank) pairs and drop duplicates, so that all
   // connections of one entity form a single contiguous run.
   index_rank.Sort();
   index_rank.Unique();

   entity_group.SetSize(nentities);
   entity_group = 0;

   CommGroup group;
   const int total = index_rank.Size();
   int run_begin = 0;
   while (run_begin < total)
   {
      const int index = index_rank[run_begin].from;
      if (index >= nentities) { break; } // remaining entries are out of range

      // find the end of this entity's run
      int run_end = run_begin + 1;
      while (run_end < total && index_rank[run_end].from == index) { run_end++; }

      // the ranks in the run form the entity's communication group
      group.clear();
      for (int k = run_begin; k < run_end; k++)
      {
         group.push_back(index_rank[k].to);
      }

      entity_group[index] = GetGroupId(group);
      run_begin = run_end;
   }
}

void ParNCMesh::AddConnections(int entity, int index, const Array<int> &ranks)
{
   for (auto rank : ranks)
   {
      entity_index_rank[entity].Append(Connection(index, rank));
   }
}

/** Compute entity_pmat_group[0..2] for vertices, edges and faces.

    Starts from the conforming (index, rank) connections left in
    entity_index_rank[] by the Build*List() routines. For every shared master
    edge, the owner ranks of its slave edges are connected to the master edge
    and its two vertices. For every shared master face, the owner ranks of its
    slaves are connected to the master face and its vertices and edges. The
    accumulated connections are then compressed into groups, and
    entity_index_rank[] is released. (The name suggests these groups feed the
    parallel prolongation matrix -- TODO confirm.) */
void ParNCMesh::CalculatePMatrixGroups()
{
   // make sure all entity_index_rank[i] arrays are filled
   GetSharedVertices();
   GetSharedEdges();
   GetSharedFaces();

   int v[4], e[4], eo[4];

   Array<int> ranks;
   ranks.Reserve(256);

   // connect slave edges to master edges and their vertices
   for (const auto &master_edge : shared_edges.masters)
   {
      ranks.SetSize(0);
      for (int j = master_edge.slaves_begin; j < master_edge.slaves_end; j++)
      {
         // owner groups are singletons, so [0] is the owner rank
         int owner = entity_owner[1][edge_list.slaves[j].index];
         ranks.Append(groups[owner][0]);
      }
      ranks.Sort();
      ranks.Unique();

      AddConnections(1, master_edge.index, ranks);

      GetEdgeVertices(master_edge, v);
      for (int j = 0; j < 2; j++)
      {
         AddConnections(0, v[j], ranks);
      }
   }

   // connect slave faces to master faces and their edges and vertices
   for (const auto &master_face : shared_faces.masters)
   {
      ranks.SetSize(0);
      for (int j = master_face.slaves_begin; j < master_face.slaves_end; j++)
      {
         // a negative index encodes an edge slave as -1 - edge_index
         int si = face_list.slaves[j].index;
         int owner = (si >= 0) ? entity_owner[2][si] // standard face dependency
                     /*     */ : entity_owner[1][-1 - si]; // prism edge-face dep
         ranks.Append(groups[owner][0]);
      }
      ranks.Sort();
      ranks.Unique();

      AddConnections(2, master_face.index, ranks);

      int nfv = GetFaceVerticesEdges(master_face, v, e, eo);
      for (int j = 0; j < nfv; j++)
      {
         AddConnections(0, v[j], ranks);
         AddConnections(1, e[j], ranks);
      }
   }

   // groups are computed for local and ghost entities alike
   int nentities[3] =
   {
      NVertices + NGhostVertices,
      NEdges + NGhostEdges,
      NFaces + NGhostFaces
   };

   // compress the index-rank arrays into group representation
   for (int i = 0; i < 3; i++)
   {
      CreateGroups(nentities[i], entity_index_rank[i], entity_pmat_group[i]);
      entity_index_rank[i].DeleteAll();
   }
}

/** Return the orientation of @a face as seen from @a e2, assuming the face has
    orientation 0 in @a e1.

    @param face   The shared face (its p1, p2, p3 node ids identify it).
    @param e1     Reference element.
    @param e2     Element whose relative orientation is returned.
    @param local  Optional output (may be NULL): local face numbers of the face
                  in e1 and e2.
    @return Mesh::GetQuadOrientation() or Mesh::GetTriOrientation() of the two
            node-id sequences. */
int ParNCMesh::get_face_orientation(const Face &face, const Element &e1,
                                    const Element &e2,
                                    int local[2])
{
   // Return face orientation in e2, assuming the face has orientation 0 in e1.
   int ids[2][4];
   const Element * const e[2] = { &e1, &e2 };
   for (int i = 0; i < 2; i++)
   {
      // get local face number (remember that p1, p2, p3 are not in order, and
      // p4 is not stored)
      int lf = find_local_face(e[i]->Geom(),
                               find_node(*e[i], face.p1),
                               find_node(*e[i], face.p2),
                               find_node(*e[i], face.p3));
      // optional output
      if (local) { local[i] = lf; }

      // get node IDs for the face as seen from e[i]
      // NOTE(review): for triangular faces this assumes fv[3] maps to an
      // element node slot holding a negative id (checked below) -- TODO
      // confirm against the GeomInfo face tables.
      const int* fv = GI[e[i]->Geom()].faces[lf];
      for (int j = 0; j < 4; j++)
      {
         ids[i][j] = e[i]->node[fv[j]];
      }
   }

   // a valid 4th node id means a quadrilateral face
   return (ids[0][3] >= 0) ? Mesh::GetQuadOrientation(ids[0], ids[1])
          /*            */ : Mesh::GetTriOrientation(ids[0], ids[1]);
}

void ParNCMesh::CalcFaceOrientations()
{
   if (Dim < 3) { return; }

   // Calculate orientation of shared conforming faces.
   // NOTE: face orientation is calculated relative to its lower rank element.
   // Thanks to the ghost layer this can be done locally, without communication.

   face_orient.SetSize(NFaces);
   face_orient = 0;

   for (const auto &face : faces)
   {
      if (face.elem[0] >= 0 && face.elem[1] >= 0 && face.index < NFaces)
      {
         Element *e1 = &elements[face.elem[0]];
         Element *e2 = &elements[face.elem[1]];

         if (e1->rank == e2->rank) { continue; }
         if (e1->rank > e2->rank) { std::swap(e1, e2); }

         face_orient[face.index] = get_face_orientation(face, *e1, *e2);
      }
   }
}

/** Parallel version of NCMesh::GetBoundaryClosure().

    After the serial closure, masters owned by this rank are also marked
    essential when any of their slaves is an essential boundary entity:
    master faces in 3D, master edges in 2D. This covers masters that have only
    slave children. Finally, each output array is filtered to local
    (non-ghost) indices, sorted and deduplicated.

    @param bdr_attr_is_ess  Per boundary attribute (attribute-1 indexing):
                            nonzero if essential.
    @param bdr_vertices     Output: sorted unique local essential vertices.
    @param bdr_edges        Output: sorted unique local essential edges.
    @param bdr_faces        Output: sorted unique local essential faces. */
void ParNCMesh::GetBoundaryClosure(const Array<int> &bdr_attr_is_ess,
                                   Array<int> &bdr_vertices,
                                   Array<int> &bdr_edges, Array<int> &bdr_faces)
{
   NCMesh::GetBoundaryClosure(bdr_attr_is_ess, bdr_vertices, bdr_edges, bdr_faces);

   if (Dim == 3)
   {
      // Mark masters of shared slave boundary faces as essential boundary
      // faces. Some master faces may only have slave children.
      for (const auto &mf : shared_faces.masters)
      {
         if (elements[mf.element].rank != MyRank) { continue; }
         for (int j = mf.slaves_begin; j < mf.slaves_end; j++)
         {
            const auto &sf = GetFaceList().slaves[j];
            if (sf.index < 0)
            {
               // Edge-face constraint. Skip this edge.
               continue;
            }
            const Face &face = *GetFace(elements[sf.element], sf.local);
            if (face.Boundary() && bdr_attr_is_ess[face.attribute - 1])
            {
               bdr_faces.Append(mf.index);
            }
         }
      }
   }
   else if (Dim == 2)
   {
      // Mark masters of shared slave boundary edges as essential boundary
      // edges. Some master edges may only have slave children.
      for (const auto &me : shared_edges.masters)
      {
         if (elements[me.element].rank != MyRank) { continue; }
         for (int j = me.slaves_begin; j < me.slaves_end; j++)
         {
            const auto &se = GetEdgeList().slaves[j];
            Face *face = GetFace(elements[se.element], se.local);
            if (face && face->Boundary() && bdr_attr_is_ess[face->attribute - 1])
            {
               bdr_edges.Append(me.index);
            }
         }
      }
   }

   // Filter, sort and unique an array, so it contains only local unique values.
   // Indices >= N (ghost entities) are dropped.
   auto FilterSortUnique = [](Array<int> &v, int N)
   {
      // Perform the O(N) filter before the O(NlogN) sort.
      auto local = std::remove_if(v.begin(), v.end(), [N](int i) { return i >= N; });
      std::sort(v.begin(), local);
      v.SetSize(std::distance(v.begin(), std::unique(v.begin(), local)));
   };

   FilterSortUnique(bdr_vertices, NVertices);
   FilterSortUnique(bdr_edges, NEdges);
   FilterSortUnique(bdr_faces, NFaces);
}


//// Neighbors /////////////////////////////////////////////////////////////////

/** Lazily classify leaf elements and build ghost_layer and boundary_layer.

    element_type[leaf] ends up as:
    - 0: element of another rank, not in the ghost layer,
    - 1: own element, not adjacent to the ghost layer,
    - 2: ghost layer (neighbors of own elements),
    - 3: own element adjacent to the ghost layer (boundary layer).
    Nothing is recomputed while element_type is non-empty (Update() clears
    it). This assumes FindSetNeighbors() returns the neighbors of a set
    excluding the set itself -- TODO confirm. */
void ParNCMesh::UpdateLayers()
{
   if (element_type.Size()) { return; }

   int nleaves = leaf_elements.Size();

   // initial marking: 1 = own element, 0 = someone else's
   element_type.SetSize(nleaves);
   for (int i = 0; i < nleaves; i++)
   {
      element_type[i] = (elements[leaf_elements[i]].rank == MyRank) ? 1 : 0;
   }

   // determine the ghost layer
   Array<char> ghost_set;
   FindSetNeighbors(element_type, NULL, &ghost_set);

   // find the neighbors of the ghost layer
   Array<char> boundary_set;
   FindSetNeighbors(ghost_set, NULL, &boundary_set);

   ghost_layer.SetSize(0);
   boundary_layer.SetSize(0);
   for (int i = 0; i < nleaves; i++)
   {
      char &etype = element_type[i];
      if (ghost_set[i])
      {
         etype = 2;
         ghost_layer.Append(leaf_elements[i]);
      }
      else if (boundary_set[i] && etype) // only own elements
      {
         etype = 3;
         boundary_layer.Append(leaf_elements[i]);
      }
   }
}

bool ParNCMesh::CheckElementType(int elem, int type)
{
   Element &el = elements[elem];
   if (!el.ref_type)
   {
      return (element_type[el.index] == type);
   }
   else
   {
      for (int i = 0; i < 8 && el.child[i] >= 0; i++)
      {
         if (!CheckElementType(el.child[i], type)) { return false; }
      }
      return true;
   }
}

void ParNCMesh::ElementNeighborProcessors(int elem, Array<int> &ranks)
{
   ranks.SetSize(0); // empty, but keep the allocated capacity

   // Shortcut: if the whole element is of type 1 (own elements not adjacent
   // to the ghost layer), it has no off-rank neighbors.
   if (CheckElementType(elem, 1)) { return; }

   // Otherwise search for neighbors, restricted to the ghost layer.
   tmp_neighbors.SetSize(0);
   FindNeighbors(elem, tmp_neighbors, &ghost_layer);

   // report the sorted, unique ranks of those neighbors
   for (int nb : tmp_neighbors)
   {
      ranks.Append(elements[nb].rank);
   }
   ranks.Sort();
   ranks.Unique();
}

/// Copy the (ordered) contents of a std::set into an mfem Array, replacing
/// the array's previous contents.
template<class T>
static void set_to_array(const std::set<T> &set, Array<T> &array)
{
   array.Reserve(static_cast<int>(set.size()));
   array.SetSize(0);
   for (auto it = set.begin(); it != set.end(); ++it)
   {
      array.Append(*it);
   }
}

void ParNCMesh::NeighborProcessors(Array<int> &neighbors)
{
   UpdateLayers();

   // TODO: look at groups instead?

   // The neighbor ranks are the distinct owners of ghost-layer elements;
   // std::set deduplicates and sorts them.
   std::set<int> ranks;
   for (int ghost : ghost_layer)
   {
      ranks.insert(elements[ghost].rank);
   }
   set_to_array(ranks, neighbors);
}


//// ParMesh compatibility /////////////////////////////////////////////////////

/** Build ParMesh-style shared-entity data for entity type @a ent
    (0 = vertices, 1 = edges, 2 = faces) from entity_conf_group[ent].
    Assumes entity_conf_group[ent] already holds ParMesh group ids, with 0
    meaning "not shared".

    @param ngroups       Number of ParMesh groups (including group 0).
    @param ent           Entity type index.
    @param shared_local  Output: local entity index of each shared entity.
    @param group_shared  Output: row g-1 lists the positions (in shared_local)
                         of the entities shared with group g.
    @param entity_geom   Optional per-entity geometry; if given, only entities
                         whose geometry equals @a geom are included.
    @param geom          Geometry filter used with @a entity_geom.

    Within each row, entries are sorted by the global SFC index of their
    representative element, then by local entity number, so that all ranks of
    a group list the shared entities in the same order. */
void ParNCMesh::MakeSharedTable(int ngroups, int ent, Array<int> &shared_local,
                                Table &group_shared, Array<char> *entity_geom,
                                char geom)
{
   const Array<GroupId> &conf_group = entity_conf_group[ent];

   // two-pass Table construction: count row sizes first, then fill
   group_shared.MakeI(ngroups-1);

   // count shared entities
   int num_shared = 0;
   for (int i = 0; i < conf_group.Size(); i++)
   {
      if (conf_group[i])
      {
         if (entity_geom && (*entity_geom)[i] != geom) { continue; }

         num_shared++;
         group_shared.AddAColumnInRow(conf_group[i]-1);
      }
   }

   shared_local.SetSize(num_shared);
   group_shared.MakeJ();

   // fill shared_local and group_shared
   for (int i = 0, j = 0; i < conf_group.Size(); i++)
   {
      if (conf_group[i])
      {
         if (entity_geom && (*entity_geom)[i] != geom) { continue; }

         shared_local[j] = i;
         group_shared.AddConnection(conf_group[i]-1, j);
         j++;
      }
   }
   group_shared.ShiftUpI();

   // sort the groups consistently across processors
   for (int i = 0; i < group_shared.Size(); i++)
   {
      int size = group_shared.RowSize(i);
      int *row = group_shared.GetRow(i);

      // wraps the row in place, so sorting reorders group_shared directly
      Array<int> ref_row(row, size);
      ref_row.Sort([&](const int a, const int b)
      {
         // el_loc packs (leaf element index << 4) | local entity number
         int el_loc_a = entity_elem_local[ent][shared_local[a]];
         int el_loc_b = entity_elem_local[ent][shared_local[b]];

         int lsi_a = leaf_sfc_index[el_loc_a >> 4];
         int lsi_b = leaf_sfc_index[el_loc_b >> 4];

         if (lsi_a != lsi_b) { return lsi_a < lsi_b; }

         return (el_loc_a & 0xf) < (el_loc_b & 0xf);
      });
   }
}

/** Fill the conforming shared-entity structures of @a pmesh: group topology,
    shared vertex/edge/face index maps and group tables, and the shared edge,
    triangle and quad element lists.

    Non-singleton ncmesh groups (plus group 0) become ParMesh groups.
    entity_conf_group[] is renumbered to ParMesh group ids, and the shared
    tables are built with a globally consistent ordering. Shared edges and
    faces are then reconstructed from their representative (element, local
    entity) pairs. entity_conf_group[] and entity_elem_local[] are freed at
    the end for ent < Dim. */
void ParNCMesh::GetConformingSharedStructures(ParMesh &pmesh)
{
   if (leaf_elements.Size())
   {
      // make sure we have entity_conf_group[x] and the ordering arrays
      for (int ent = 0; ent < Dim; ent++)
      {
         GetSharedList(ent);
         MFEM_VERIFY(entity_conf_group[ent].Size() ||
                     pmesh.GetNE() == 0, "Non empty partitions must be connected");
         MFEM_VERIFY(entity_elem_local[ent].Size() ||
                     pmesh.GetNE() == 0, "Non empty partitions must be connected");
      }
   }

   // create ParMesh groups, and the map (ncmesh_group -> pmesh_group)
   Array<int> group_map(static_cast<int>(groups.size()));
   {
      group_map = 0;
      IntegerSet iset;
      ListOfIntegerSets int_groups;
      for (unsigned i = 0; i < groups.size(); i++)
      {
         if (groups[i].size() > 1 || !i) // skip singleton groups
         {
            iset.Recreate(static_cast<int>(groups[i].size()), groups[i].data());
            group_map[i] = int_groups.Insert(iset);
         }
      }
      // 822: presumably the MPI tag used by GroupTopology -- TODO confirm
      pmesh.gtopo.Create(int_groups, 822);
   }

   // renumber groups in entity_conf_group[] (due to missing singletons)
   for (int ent = 0; ent < 3; ent++)
   {
      for (int i = 0; i < entity_conf_group[ent].Size(); i++)
      {
         GroupId &ecg = entity_conf_group[ent][i];
         ecg = group_map[ecg];
      }
   }

   // create shared to local index mappings and group tables
   int ng = pmesh.gtopo.NGroups();
   MakeSharedTable(ng, 0, pmesh.svert_lvert, pmesh.group_svert);
   MakeSharedTable(ng, 1, pmesh.sedge_ledge, pmesh.group_sedge);

   // faces are split by geometry: triangles first, then quads
   Array<int> slt, slq;
   MakeSharedTable(ng, 2, slt, pmesh.group_stria, &face_geom, Geometry::TRIANGLE);
   MakeSharedTable(ng, 2, slq, pmesh.group_squad, &face_geom, Geometry::SQUARE);

   pmesh.sface_lface = slt;
   pmesh.sface_lface.Append(slq);

   // create shared_edges
   for (int i = 0; i < pmesh.shared_edges.Size(); i++)
   {
      delete pmesh.shared_edges[i];
   }
   pmesh.shared_edges.SetSize(pmesh.sedge_ledge.Size());
   for (int i = 0; i < pmesh.shared_edges.Size(); i++)
   {
      // el_loc packs (leaf element index << 4) | local edge number
      int el_loc = entity_elem_local[1][pmesh.sedge_ledge[i]];
      MeshId edge_id(-1, leaf_elements[(el_loc >> 4)], (el_loc & 0xf));

      int v[2];
      GetEdgeVertices(edge_id, v, false);
      pmesh.shared_edges[i] = new Segment(v, 1);
   }

   // create shared_trias
   pmesh.shared_trias.SetSize(slt.Size());
   for (int i = 0; i < slt.Size(); i++)
   {
      int el_loc = entity_elem_local[2][slt[i]];
      MeshId face_id(-1, leaf_elements[(el_loc >> 4)], (el_loc & 0xf));

      int v[4], e[4], eo[4];
      GetFaceVerticesEdges(face_id, v, e, eo);
      pmesh.shared_trias[i].Set(v);
   }

   // create shared_quads
   pmesh.shared_quads.SetSize(slq.Size());
   for (int i = 0; i < slq.Size(); i++)
   {
      int el_loc = entity_elem_local[2][slq[i]];
      MeshId face_id(-1, leaf_elements[(el_loc >> 4)], (el_loc & 0xf));

      int e[4], eo[4];
      GetFaceVerticesEdges(face_id, pmesh.shared_quads[i].v, e, eo);
   }

   // free the arrays, they're not needed anymore (until next mesh update)
   for (int ent = 0; ent < Dim; ent++)
   {
      entity_conf_group[ent].DeleteAll();
      entity_elem_local[ent].DeleteAll();
   }
}

// Fills the face-neighbor data of 'pmesh' from the shared faces (3D) or shared
// edges (2D) of this nonconforming mesh:
//  - face_nbr_group: ranks owning face-neighbor (ghost) elements, in ascending
//    order;
//  - face_nbr_elements / face_nbr_elements_offset: one new mfem::Element per
//    face-neighbor element, grouped by rank, whose vertices index the locally
//    numbered face_nbr_vertices;
//  - send_face_nbr_elements: for each neighbor rank (by position in
//    face_nbr_group), the local elements adjacent to that rank.
// It then patches pmesh.faces_info so that the neighbor side of each shared
// face refers to a face-neighbor element as (-1 - index), and appends
// NCFaceInfo entries for shared slave faces. In 3D it also builds
// face_nbr_el_to_face and, if any involved element has a non-tensor-product
// geometry, exchanges face orientations with the neighbor ranks to build
// face_nbr_el_ori (point-to-point MPI communication on pmesh.MyComm).
// Point matrices allocated here for ghost slave faces are kept in
// aux_pm_store and released by ClearAuxPM().
void ParNCMesh::GetFaceNeighbors(ParMesh &pmesh)
{
   // Release point matrices allocated by a previous call.
   ClearAuxPM();

   // In 3D the interfaces between ranks are faces, in 2D they are edges.
   const NCList &shared = (Dim == 3) ? GetSharedFaces() : GetSharedEdges();
   const NCList &full_list = (Dim == 3) ? GetFaceList() : GetEdgeList();

   // fnbr: remote (face-neighbor) elements; send_elems: (remote rank, local
   // element index) pairs; recv_elems: remote rank -> indices of its elements
   // that are face neighbors here.
   Array<Element*> fnbr;
   Array<Connection> send_elems;
   std::map<int, std::vector<int>> recv_elems;

   // Accumulator for std::accumulate: adds the number of slave faces of master
   // 'x' to the running total 'i'. This may be larger than the number of
   // shared slaves if there exist degenerate slave-faces from face-edge
   // constraints.
   auto count_slaves = [&](int i, const Master& x)
   {
      return i + (x.slaves_end - x.slaves_begin);
   };

   // Upper bound on the number of face-neighbor entries collected below.
   const int bound = shared.conforming.Size() + std::accumulate(
                        shared.masters.begin(), shared.masters.end(),
                        0, count_slaves);

   fnbr.Reserve(bound);
   send_elems.Reserve(bound);

   // If there are face neighbor elements with triangular faces, the
   // `face_nbr_el_ori` structure will need to be built. This requires
   // communication so we attempt to avoid it by checking first.
   bool face_nbr_w_tri_faces = false;

   // go over all shared faces and collect face neighbor elements
   for (int i = 0; i < shared.conforming.Size(); i++)
   {
      const MeshId &cf = shared.conforming[i];
      Face* face = GetFace(elements[cf.element], cf.local);
      MFEM_ASSERT(face != NULL, "");

      MFEM_ASSERT(face->elem[0] >= 0 && face->elem[1] >= 0, "");
      Element* e[2] = { &elements[face->elem[0]], &elements[face->elem[1]] };

      // Order the pair so that e[0] is remote and e[1] is local.
      if (e[0]->rank == MyRank) { std::swap(e[0], e[1]); }
      MFEM_ASSERT(e[0]->rank != MyRank && e[1]->rank == MyRank, "");

      face_nbr_w_tri_faces |= !Geometry::IsTensorProduct(Geometry::Type(e[0]->geom));
      face_nbr_w_tri_faces |= !Geometry::IsTensorProduct(Geometry::Type(e[1]->geom));

      fnbr.Append(e[0]);
      send_elems.Append(Connection(e[0]->rank, e[1]->index));
      recv_elems[e[0]->rank].push_back(e[0]->index);
   }

   // Same for master/slave pairs of shared nonconforming faces.
   for (int i = 0; i < shared.masters.Size(); i++)
   {
      const Master &mf = shared.masters[i];
      for (int j = mf.slaves_begin; j < mf.slaves_end; j++)
      {
         const Slave &sf = full_list.slaves[j];
         // Skip slaves without an element and degenerate (edge-face
         // constraint) slaves, which have a negative index.
         if (sf.element < 0 || sf.index < 0) { continue; }

         MFEM_ASSERT(mf.element >= 0, "");
         Element* e[2] = { &elements[mf.element], &elements[sf.element] };

         bool loc0 = (e[0]->rank == MyRank);
         bool loc1 = (e[1]->rank == MyRank);
         if (loc0 == loc1)
         {
            // neither or both of these elements are on this rank.
            continue;
         }
         // Order the pair so that e[0] is remote and e[1] is local.
         if (loc0) { std::swap(e[0], e[1]); }

         face_nbr_w_tri_faces |= !Geometry::IsTensorProduct(Geometry::Type(e[0]->geom));
         face_nbr_w_tri_faces |= !Geometry::IsTensorProduct(Geometry::Type(e[1]->geom));

         fnbr.Append(e[0]);
         send_elems.Append(Connection(e[0]->rank, e[1]->index));
         recv_elems[e[0]->rank].push_back(e[0]->index);
      }
   }

   MFEM_ASSERT(fnbr.Size() <= bound,
               "oops, bad upper bound. fnbr.Size(): " << fnbr.Size() << ", bound: " << bound);

   // remove duplicate face neighbor elements and sort them by rank & index
   // (note that the send table is sorted the same way and the order is also the
   // same on different processors, this is important for ExchangeFaceNbrData)
   fnbr.Sort();
   fnbr.Unique();
   fnbr.Sort([](const Element* a, const Element* b)
   {
      return (a->rank != b->rank) ? a->rank < b->rank
             /*                */ : a->index < b->index;
   });

   // put the ranks into 'face_nbr_group' (fnbr is sorted by rank, so each
   // distinct rank is appended once, in ascending order)
   for (int i = 0; i < fnbr.Size(); i++)
   {
      if (!i || fnbr[i]->rank != pmesh.face_nbr_group.Last())
      {
         pmesh.face_nbr_group.Append(fnbr[i]->rank);
      }
   }
   const int nranks = pmesh.face_nbr_group.Size();

   // create a new mfem::Element for each face neighbor element
   pmesh.face_nbr_elements.SetSize(0);
   pmesh.face_nbr_elements.Reserve(fnbr.Size());

   pmesh.face_nbr_elements_offset.SetSize(0);
   pmesh.face_nbr_elements_offset.Reserve(pmesh.face_nbr_group.Size()+1);

   // Maps a ghost element (index - NElements) to its position in 'fnbr';
   // -1 for ghosts that are not face neighbors.
   Array<int> fnbr_index(NGhostElements);
   fnbr_index = -1;

   // NC node id -> 1-based face-neighbor vertex number (0 = not yet numbered,
   // the value std::map::operator[] default-inserts).
   std::map<int, int> vert_map;
   for (int i = 0; i < fnbr.Size(); i++)
   {
      NCMesh::Element* elem = fnbr[i];
      mfem::Element* fne = NewMeshElement(elem->geom);
      fne->SetAttribute(elem->attribute);
      pmesh.face_nbr_elements.Append(fne);

      GeomInfo& gi = GI[(int) elem->geom];
      for (int k = 0; k < gi.nv; k++)
      {
         int &v = vert_map[elem->node[k]];
         if (!v) { v = static_cast<int>(vert_map.size()); }
         fne->GetVertices()[k] = v-1;
      }

      // A new rank starts a new group of face-neighbor elements.
      if (!i || elem->rank != fnbr[i-1]->rank)
      {
         pmesh.face_nbr_elements_offset.Append(i);
      }

      MFEM_ASSERT(elem->index >= NElements, "not a ghost element");
      fnbr_index[elem->index - NElements] = i;
   }
   pmesh.face_nbr_elements_offset.Append(fnbr.Size());

   // create vertices in 'face_nbr_vertices'
   {
      pmesh.face_nbr_vertices.SetSize(static_cast<int>(vert_map.size()));
      if (coordinates.Size())
      {
         // tmp_vertex is a temporary buffer set up for CalcVertexPos and
         // freed right after use.
         tmp_vertex = new TmpVertex[nodes.NumIds()]; // TODO: something cheaper?
         for (const auto &v : vert_map)
         {
            pmesh.face_nbr_vertices[v.second-1].SetCoords(
               spaceDim, CalcVertexPos(v.first));
         }
         delete [] tmp_vertex;
      }
   }

   // make the 'send_face_nbr_elements' table
   send_elems.Sort();
   send_elems.Unique();

   // Sort and deduplicate the received element indices per rank, matching the
   // (rank, index) order of 'fnbr'.
   for (auto &kv : recv_elems)
   {
      std::sort(kv.second.begin(), kv.second.end());
      kv.second.erase(std::unique(kv.second.begin(), kv.second.end()),
                      kv.second.end());
   }

   // send_elems is sorted by rank, so the position in face_nbr_group only
   // needs to be searched when the rank changes.
   for (int i = 0, last_rank = -1; i < send_elems.Size(); i++)
   {
      Connection &c = send_elems[i];
      if (c.from != last_rank)
      {
         // renumber rank to position in 'face_nbr_group'
         last_rank = c.from;
         c.from = pmesh.face_nbr_group.Find(c.from);
      }
      else
      {
         c.from = send_elems[i-1].from; // avoid search
      }
   }
   pmesh.send_face_nbr_elements.MakeFromList(nranks, send_elems);

   // go over the shared faces again and modify their Mesh::FaceInfo
   for (const auto& cf : shared.conforming)
   {
      Face* face = GetFace(elements[cf.element], cf.local);
      Element* e[2] = { &elements[face->elem[0]], &elements[face->elem[1]] };
      if (e[0]->rank == MyRank) { std::swap(e[0], e[1]); }

      // Side 2 is the remote element, referenced as -1 - (face-neighbor index).
      Mesh::FaceInfo &fi = pmesh.faces_info[cf.index];
      fi.Elem2No = -1 - fnbr_index[e[0]->index - NElements];

      // Elem2Inf is encoded as 64 * (local face/edge index) + orientation.
      if (Dim == 3)
      {
         int local[2];
         int o = get_face_orientation(*face, *e[1], *e[0], local);
         fi.Elem2Inf = 64*local[1] + o;
      }
      else
      {
         fi.Elem2Inf = 64*find_element_edge(*e[0], face->p1, face->p3) + 1;
      }
   }

   // If there are shared slaves, they will also need to be updated.
   if (shared.slaves.Size())
   {
      int nfaces = NFaces, nghosts = NGhostFaces;
      if (Dim <= 2) { nfaces = NEdges, nghosts = NGhostEdges; }

      // enlarge Mesh::faces_info for ghost slaves
      MFEM_ASSERT(pmesh.faces_info.Size() == nfaces, "");
      MFEM_ASSERT(pmesh.GetNumFaces() == nfaces, "");
      pmesh.faces_info.SetSize(nfaces + nghosts);
      for (int i = nfaces; i < pmesh.faces_info.Size(); i++)
      {
         Mesh::FaceInfo &fi = pmesh.faces_info[i];
         fi.Elem1No  = fi.Elem2No  = -1;
         fi.Elem1Inf = fi.Elem2Inf = -1;
         fi.NCFace = -1;
      }
      // Note that some of the indices i >= nfaces in pmesh.faces_info will
      // remain untouched below and they will have Elem1No == -1, in particular.

      // fill in FaceInfo for shared slave faces
      for (int i = 0; i < shared.masters.Size(); i++)
      {
         const Master &mf = shared.masters[i];
         for (int j = mf.slaves_begin; j < mf.slaves_end; j++)
         {
            const Slave &sf = full_list.slaves[j];
            if (sf.element < 0) { continue; }

            MFEM_ASSERT(mf.element >= 0, "");
            Element &sfe = elements[sf.element];
            Element &mfe = elements[mf.element];

            bool sloc = (sfe.rank == MyRank);
            bool mloc = (mfe.rank == MyRank);
            if (sloc == mloc // both or neither face is owned by this processor
                || sf.index < 0) // the face is degenerate (i.e. an edge-face constraint)
            {
               continue;
            }

            // This is a genuine slave face, the info associated with it must
            // be updated.
            Mesh::FaceInfo &fi = pmesh.faces_info[sf.index];
            fi.Elem1No = sfe.index;
            fi.Elem2No = mfe.index;
            fi.Elem1Inf = 64 * sf.local;
            fi.Elem2Inf = 64 * mf.local;

            if (!sloc)
            {
               // 'fi' is the info for a ghost slave face with index:
               // sf.index >= nfaces
               std::swap(fi.Elem1No, fi.Elem2No);
               std::swap(fi.Elem1Inf, fi.Elem2Inf);
               // After the above swap, Elem1No refers to the local, master-side
               // element. In other words, side 1 IS NOT the side that generated
               // the face.
            }
            else
            {
               // 'fi' is the info for a local slave face with index:
               // sf.index < nfaces
               // Here, Elem1No refers to the local, slave-side element.
               // In other words, side 1 IS the side that generated the face.
            }
            // Side 2 is always the remote element; convert it to the
            // face-neighbor encoding -1 - (index in 'fnbr').
            MFEM_ASSERT(fi.Elem2No >= NElements, "");
            fi.Elem2No = -1 - fnbr_index[fi.Elem2No - NElements];

            const DenseMatrix* pm = full_list.point_matrices[sf.geom][sf.matrix];
            if (!sloc && Dim == 3)
            {
               // ghost slave in 3D needs flipping orientation
               DenseMatrix* pm2 = new DenseMatrix(*pm);
               if (sf.geom == Geometry::Type::SQUARE)
               {
                  std::swap((*pm2)(0, 1), (*pm2)(0, 3));
                  std::swap((*pm2)(1, 1), (*pm2)(1, 3));
               }
               else if (sf.geom == Geometry::Type::TRIANGLE)
               {
                  std::swap((*pm2)(0, 0), (*pm2)(0, 1));
                  std::swap((*pm2)(1, 0), (*pm2)(1, 1));
               }
               // Owned by aux_pm_store, freed by ClearAuxPM().
               aux_pm_store.Append(pm2);

               fi.Elem2Inf ^= 1;
               pm = pm2;

               // The problem is that sf.point_matrix is designed for P matrix
               // construction and always has orientation relative to the slave
               // face. In ParMesh::GetSharedFaceTransformations the result
               // would therefore be the same on both processors, which is not
               // how that function works for conforming faces. The orientation
               // of Loc1, Loc2 and Face needs to always be relative to Element
               // 1, which is the element containing the slave face on one
               // processor, but on the other it is the element containing the
               // master face. In the latter case we need to flip the pm.
            }
            else if (!sloc && Dim == 2)
            {
               fi.Elem2Inf ^= 1; // set orientation to 1
               // The point matrix (used to define "side 1" which is the same as
               // "parent side" in this case) does not require a flip since it
               // is aligned with the parent side, so NO flip is performed in
               // Mesh::ApplyLocalSlaveTransformation.
            }

            MFEM_ASSERT(fi.NCFace < 0, "fi.NCFace = " << fi.NCFace);
            fi.NCFace = pmesh.nc_faces_info.Size();
            pmesh.nc_faces_info.Append(Mesh::NCFaceInfo(true, sf.master, pm));
         }
      }
   }


   // In 3D some extra orientation data structures can be needed.
   if (Dim == 3)
   {
      // Populates face_nbr_el_to_face, always needed.
      pmesh.BuildFaceNbrElementToFaceTable();

      if (face_nbr_w_tri_faces)
      {
         // There are face neighbor elements with triangular faces, need to
         // perform communication to ensure the orientation is valid.
         using RankToOrientation = std::map<int, std::vector<std::array<int, 6>>>;
         constexpr std::array<int, 6> unset_ori{{-1,-1,-1,-1,-1,-1}};
         const int rank = pmesh.GetMyRank();

         // Loop over send elems, compute the orientation and place in the
         // buffer to send to each processor. Note elements are
         // lexicographically sorted with rank and element number, and this
         // ordering holds across processors.
         RankToOrientation send_rank_to_face_neighbor_orientations;
         Array<int> orientations, faces;

         // send_elems goes from the receiving processor (as a position in
         // face_nbr_group, after the renumbering above) to the index of the
         // local element adjacent to it.
         for (const auto &se : send_elems)
         {
            const auto &true_rank = pmesh.face_nbr_group[se.from];
            pmesh.GetElementFaces(se.to, faces, orientations);

            // Place a new entry of unset orientations
            send_rank_to_face_neighbor_orientations[true_rank].emplace_back(unset_ori);

            // Copy the entries, any unset faces will remain -1.
            std::copy(orientations.begin(), orientations.end(),
                      send_rank_to_face_neighbor_orientations[true_rank].back().begin());
         }

         // Initialize the receive buffers and resize to match the expected
         // number of elements coming in. The copy ensures the appropriate rank
         // pairings are in place, and for a purely conformal interface, the
         // resize is a no-op.
         auto recv_rank_to_face_neighbor_orientations =
            send_rank_to_face_neighbor_orientations;
         for (auto &kv : recv_rank_to_face_neighbor_orientations)
         {
            kv.second.resize(recv_elems[kv.first].size());
         }

         // For asynchronous send/recv, will use arrays of requests to monitor the
         // status of the connections. 'status' is shared by both MPI_Waitall
         // calls below; each uses at most nranks requests.
         std::vector<MPI_Request> send_requests, recv_requests;
         std::vector<MPI_Status> status(nranks);

         // NOTE: This is CRITICAL, to ensure the addresses of these requests
         // do not change between the send/recv and the wait.
         send_requests.reserve(nranks);
         recv_requests.reserve(nranks);

         // Shared face communication is bidirectional -> any rank to whom
         // orientations must be sent, will need to send orientations back. The
         // orientation data is contiguous because std::array<int,6> is an
         // aggregate. Loop over each communication pairing, and dispatch the
         // buffer loaded with all the orientation data.
         for (const auto &kv : send_rank_to_face_neighbor_orientations)
         {
            send_requests.emplace_back(); // instantiate a request for tracking.

            // Both branches of the ternary evaluate to 'rank', i.e. every
            // message is tagged with the sender's rank; the matching receive
            // below uses the peer's rank as its tag.
            // NOTE(review): MPI only guarantees tags up to MPI_TAG_UB (at least
            // 32767), so rank-valued tags may be invalid on very large runs —
            // confirm.
            const int send_tag = (rank < kv.first)
                                 ? std::min(rank, kv.first)
                                 : std::max(rank, kv.first);
            MPI_Isend(&kv.second[0][0], int(kv.second.size() * 6),
                      MPI_INT, kv.first, send_tag, pmesh.MyComm, &send_requests.back());
         }

         // Loop over the communication pairing again, and receive the
         // symmetric buffer from the other processor.
         for (auto &kv : recv_rank_to_face_neighbor_orientations)
         {
            recv_requests.emplace_back(); // instantiate a request for tracking

            // Both branches evaluate to kv.first, the sender's rank, which
            // matches the send tag used above.
            const int recv_tag = (rank < kv.first)
                                 ? std::max(rank, kv.first)
                                 : std::min(rank, kv.first);
            MPI_Irecv(&kv.second[0][0], int(kv.second.size() * 6),
                      MPI_INT, kv.first, recv_tag, pmesh.MyComm, &recv_requests.back());
         }

         // Wait until all receive buffers are full before beginning to process.
         MPI_Waitall(int(recv_requests.size()), recv_requests.data(), status.data());

         // Rows follow the (rank, element index) order of face_nbr_elements.
         pmesh.face_nbr_el_ori.reset(new Table(pmesh.face_nbr_elements.Size(), 6));
         int elem = 0;
         for (const auto &kv : recv_rank_to_face_neighbor_orientations)
         {
            // All elements associated to this face-neighbor rank
            for (const auto &eo : kv.second)
            {
               std::copy(eo.begin(), eo.end(), pmesh.face_nbr_el_ori->GetRow(elem));
               ++elem;
            }
         }
         pmesh.face_nbr_el_ori->Finalize();

         // Must wait for all send buffers to be released before the scope closes.
         MPI_Waitall(int(send_requests.size()), send_requests.data(), status.data());
      }
   }
   // NOTE: this function skips ParMesh::send_face_nbr_vertices and
   // ParMesh::face_nbr_vertices_offset, these are not used outside of ParMesh
}

void ParNCMesh::ClearAuxPM()
{
   for (int i = 0; i < aux_pm_store.Size(); i++)
   {
      delete aux_pm_store[i];
   }
   aux_pm_store.DeleteAll();
}

//// Prune, Refine, Derefine ///////////////////////////////////////////////////

bool ParNCMesh::PruneTree(int elem)
{
   // Decides, recursively, whether the subtree rooted at 'elem' can be removed.
   // A leaf is removable when its rank is negative (Prune() marks leaves that
   // are not needed on this rank with rank -1). A refined element is removable
   // when all of its existing children are; the actual removal is then left to
   // the (possibly indirect) parent. Otherwise the removable children are
   // derefined here and false is returned, so this element is kept.
   Element &el = elements[elem];
   if (!el.ref_type)
   {
      return el.rank < 0;
   }

   bool removable[8] = {};
   int num_kept = 0;
   for (int c = 0; c < 8; c++)
   {
      const int child = el.child[c];
      if (child < 0) { continue; }

      removable[c] = PruneTree(child);
      if (!removable[c]) { num_kept++; }
   }

   if (num_kept == 0) { return true; }

   for (int c = 0; c < 8; c++)
   {
      if (removable[c]) { DerefineElement(el.child[c]); }
   }
   return false;
}

void ParNCMesh::Prune()
{
   // Removes from the refinement tree every branch whose leaves all lie beyond
   // this rank's ghost layer, then calls Update(). This is not supported for
   // anisotropic 3D meshes: in that case rank 0 prints a warning and nothing
   // is pruned.
   if (Dim == 3 && !Iso)
   {
      if (MyRank == 0) { MFEM_WARNING("Can't prune 3D aniso meshes yet."); }
      return;
   }

   UpdateLayers();

   // Leaves with element_type 0 lie beyond the ghost layer and their rank is
   // unknown / not updated, so mark them with rank -1. PruneTree() treats such
   // leaves as removable, and rank -1 also makes them disappear from
   // leaf_elements on the next Update (see NCMesh::CollectLeafElements).
   const int num_leaves = leaf_elements.Size();
   for (int leaf = 0; leaf < num_leaves; leaf++)
   {
      if (element_type[leaf] != 0) { continue; }
      elements[leaf_elements[leaf]].rank = -1;
   }

   // Element indices 0 .. root_state.Size()-1 are visited as tree roots;
   // derefine every root whose whole subtree is removable.
   for (int root = 0; root < root_state.Size(); root++)
   {
      if (PruneTree(root)) { DerefineElement(root); }
   }

   Update();
}

// Detects conflicts between the requested refinements of a 3D parallel mesh
// with anisotropic refinements.
// - Returns false immediately (no communication) if Dim < 3 or NRanks == 1.
// - Otherwise it is collective over MyComm. It clears the member 'Iso' if any
//   local refinement is not Refinement::XYZ, and returns false if 'Iso' is
//   still true on every rank.
// - Otherwise it exchanges boundary refinements with the neighbor ranks and
//   checks local and ghost refinements against the local ones. Indices (into
//   'refinements') of conflicting local refinements are inserted into
//   'conflicts'. It returns true on all ranks if any rank found a conflict.
bool ParNCMesh::AnisotropicConflict(const Array<Refinement> &refinements,
                                    std::set<int> &conflicts)
{
   if (Dim < 3 || NRanks == 1) { return false; }

   // Clear the (member) Iso flag if any local refinement is anisotropic.
   for (int i = 0; i < refinements.Size() && Iso; i++)
   {
      const Refinement &ref = refinements[i];
      if (ref.GetType() != Refinement::XYZ)
      {
         Iso = false;
      }
   }

   // Reduce the Iso flag over all MPI ranks.
   bool globalIso = false;
   MPI_Allreduce(&Iso, &globalIso, 1, MFEM_MPI_CXX_BOOL, MPI_LAND, MyComm);

   if (globalIso) { return false; }

   // In the 3D parallel anisotropic case, check for conflicts on faces.
   NeighborRefinementMessage::Map send_ref;

   // Create refinement messages to all neighbors (NOTE: some may be empty).
   // Every neighbor thus receives exactly one message, which the receive loop
   // below relies on.
   Array<int> neighbors;
   NeighborProcessors(neighbors);
   for (int i = 0; i < neighbors.Size(); i++)
   {
      send_ref[neighbors[i]].SetNCMesh(this);
   }

   // Populate messages: all refinements that occur next to the processor
   // boundary need to be sent to the adjoining neighbors so they can keep
   // their ghost layer up to date.
   // NOTE(review): this assumes the ranks returned by
   // ElementNeighborProcessors are a subset of 'neighbors'; otherwise extra
   // messages would be sent that the receive loop does not consume — confirm.
   Array<int> ranks;
   ranks.Reserve(64);
   for (int i = 0; i < refinements.Size(); i++)
   {
      const Refinement &ref = refinements[i];
      MFEM_ASSERT(ref.index < NElements, "");
      const int elem = leaf_elements[ref.index];
      ElementNeighborProcessors(elem, ranks);
      for (int j = 0; j < ranks.Size(); j++)
      {
         send_ref[ranks[j]].AddRefinement(elem, ref.GetType());
      }
   }

   // Send the messages (overlap with local refinements)
   NeighborRefinementMessage::IsendAll(send_ref, MyComm);

   // Note that ghost refinements are not looked up using elemToRef. Local
   // refinements are recorded first in elemToRef, and ghosts only need to be
   // compared to local refinements. There is no need for ghost-to-ghost
   // comparisons.
   // elemToRef maps an element index (leaf_elements[ref.index]) to the
   // position of its refinement in 'refinements'.
   std::map<int, int> elemToRef; // Only for local refinements, not ghosts.
   for (int i = 0; i < refinements.Size(); i++)
   {
      elemToRef[leaf_elements[refinements[i].index]] = i;
   }

   // Check local refinements
   for (int i = 0; i < refinements.Size(); i++)
   {
      const Refinement &ref = refinements[i];
      CheckRefinement(leaf_elements[ref.index], ref.GetType(), refinements,
                      elemToRef, conflicts);
   }

   // Receive (ghost layer) refinements from all neighbors, one message per
   // neighbor, in whatever order Probe reports them.
   for (int j = 0; j < neighbors.Size(); j++)
   {
      int rank, size;
      NeighborRefinementMessage::Probe(rank, size, MyComm);

      NeighborRefinementMessage msg;
      msg.SetNCMesh(this);
      msg.Recv(rank, size, MyComm);

      // check the ghost refinements
      for (int i = 0; i < msg.Size(); i++)
      {
         CheckRefinement(msg.elements[i], msg.values[i], refinements, elemToRef,
                         conflicts);
      }
   }

   // Make sure we can delete the send buffers
   NeighborRefinementMessage::WaitAllSent(send_ref);

   // Check master faces of locally refined elements against their slaves.
   CheckRefinementMaster(refinements, elemToRef, conflicts);

   // A conflict on any rank is reported on all ranks.
   const bool conflict = conflicts.size() > 0;
   bool globalConflict = false;
   MPI_Allreduce(&conflict, &globalConflict, 1, MFEM_MPI_CXX_BOOL, MPI_LOR,
                 MyComm);
   return globalConflict;
}

int GetHexFaceDir(int face)
{
   // Returns the hexahedron axis (0, 1 or 2) normal to local face 'face'; the
   // other two axes are the face's tangential directions (see
   // GetHexFaceRefType). Face numbering follows
   // Geometry::Constants<Geometry::CUBE>::FaceVert[6][4] in fem/geom.cpp:
   //   face 0: {3, 2, 1, 0} -> 2    face 1: {0, 1, 5, 4} -> 1
   //   face 2: {1, 2, 6, 5} -> 0    face 3: {2, 3, 7, 6} -> 1
   //   face 4: {3, 0, 4, 7} -> 0    face 5: {4, 5, 6, 7} -> 2
   static constexpr int normal_axis_of_face[6] = { 2, 1, 0, 1, 0, 2 };
   return normal_axis_of_face[face];
}

char GetHexFaceRefType(const bool (&refDir)[3], int face)
{
   // Encodes how a hex refinement (refDir[d] == true means axis d is refined)
   // splits local face 'face'. The two axes tangential to the face are visited
   // in increasing order: the first contributes bit 0 (value 1), the second
   // bit 1 (value 2). So 0 = face not split, 1 or 2 = split along one
   // tangential axis, 3 = split along both.
   const int normal_axis = GetHexFaceDir(face);

   int code = 0;
   int bit_value = 1;
   for (int axis = 0; axis < 3; ++axis)
   {
      if (axis == normal_axis) { continue; }
      if (refDir[axis]) { code += bit_value; }
      bit_value *= 2;
   }
   return static_cast<char>(code);
}

// Assuming a vertical split of the master face with ordered vertices
// (vn1, vn2, vn3, vn4), check whether there is a horizontal split among the
// slave faces of this face. This recursive function is similar to
// NCMesh::CheckAnisoFace.
//
// The face counts as horizontally split when both edges (vn2, vn3) and
// (vn4, vn1) have mid-edge nodes and a node keyed by that pair of mid-edge
// nodes exists. In that case the two halves are checked recursively.
// NOTE(review): every call with level > 0 returns true, so at level 0 the
// result is simply "the face itself is horizontally split" and the recursion
// never changes the outcome. Confirm that this is intended.
bool ParNCMesh::CheckRefAnisoFaceSplits(int vn1, int vn2, int vn3, int vn4,
                                        int level)
{
   const int mid23 = FindMidEdgeNode(vn2, vn3);
   const int mid41 = FindMidEdgeNode(vn4, vn1);
   const bool split_horizontally =
      mid23 >= 0 && mid41 >= 0 && nodes.FindId(mid23, mid41) >= 0;

   if (split_horizontally)
   {
      // Check both halves; stop at the first one that reports a split.
      const bool half_reports_split =
         CheckRefAnisoFaceSplits(vn1, vn2, mid23, mid41, level + 1) ||
         CheckRefAnisoFaceSplits(mid41, mid23, vn3, vn4, level + 1);
      if (half_reports_split) { return true; }
   }

   return level > 0;
}

// For every master face whose master element has a local refinement recorded
// in 'elemToRef' (element index -> position in 'refinements'), determine how
// that refinement splits the face (GetHexFaceRefType). For each split
// direction, use CheckRefAnisoFaceSplits to test whether the face's slaves are
// split in the other direction; if so, insert the refinement's index into
// 'conflicts'. The master element is assumed to be a hexahedron, since its
// face vertices are read via Geometry::Constants<Geometry::CUBE>::FaceVert.
void ParNCMesh::CheckRefinementMaster(const Array<Refinement> &refinements,
                                      const std::map<int, int> &elemToRef,
                                      std::set<int> &conflicts)
{
   MFEM_VERIFY(Dim == 3, "");
   const NCList &faceList = GetFaceList();

   for (const auto &mf : faceList.masters)
   {
      // Check for conflicts only if the master element is marked for
      // refinement (one map lookup instead of count() followed by at()).
      const auto it = elemToRef.find(mf.element);
      if (it == elemToRef.end()) { continue; }

      const int refIndex = it->second;
      const Refinement& ref = refinements[refIndex];

      // ref.s[i] > 0 is treated as "refined along hex axis i".
      bool refDir[3];
      for (int i = 0; i < 3; ++i)
      {
         refDir[i] = ref.s[i] > real_t{0};
      }

      const char faceRefType = GetHexFaceRefType(refDir, mf.local);
      if (faceRefType == 0) { continue; } // No refinement on this face

      // Nodes of the master face, in the local order given by FaceVert.
      std::array<int, 4> fv;
      for (int i = 0; i < 4; ++i)
      {
         fv[i] = elements[mf.element].node[
                    Geometry::Constants<Geometry::CUBE>::FaceVert[mf.local][i]];
      }

      if (faceRefType != 2) // X or XY split w.r.t. the face.
      {
         // Check X face split
         if (CheckRefAnisoFaceSplits(fv[0], fv[1], fv[2], fv[3]))
         {
            conflicts.insert(refIndex);
         }
      }

      if (faceRefType != 1) // Y or XY split w.r.t. the face.
      {
         // Check Y face split
         if (CheckRefAnisoFaceSplits(fv[1], fv[2], fv[3], fv[0]))
         {
            conflicts.insert(refIndex);
         }
      }
   }
}

int FindHexFace(const int* no, int vn1, int vn2, int vn3, int vn4)
{
   std::set<int> v;
   v.insert({vn1, vn2, vn3, vn4});

   int face = -1;
   for (int f=0; f<6; ++f)
   {
      bool allFound = true;
      for (int i=0; i<4; ++i)
      {
         const int vi = no[Geometry::Constants<Geometry::CUBE>::FaceVert[f][i]];
         if (v.count(vi) == 0)
         {
            allFound = false;
         }
      }

      if (allFound)
      {
         MFEM_ASSERT(face == -1, "");
         face = f;
      }
   }

   MFEM_ASSERT(face >= 0, "");
   return face;
}

// Assumption: v1 and v2 are indices of hex vertices connected by an edge.
// The return value is {0,1,2} denoting split {X,Y,Z}.
//
// Finds the local hex edge whose two end nodes (nodes[Edges[e][0]],
// nodes[Edges[e][1]]) are {v1, v2} in either order, and returns that edge's
// axis. MFEM_ASSERT checks that exactly one edge matches.
int GetHexEdgeSplit(const int* nodes, int v1, int v2)
{
   int edge = -1;
   for (int e = 0; e < 12; ++e)
   {
      const int a = nodes[Geometry::Constants<Geometry::CUBE>::Edges[e][0]];
      const int b = nodes[Geometry::Constants<Geometry::CUBE>::Edges[e][1]];
      const bool sameEdge = (a == v1 && b == v2) || (a == v2 && b == v1);
      if (sameEdge)
      {
         MFEM_ASSERT(edge == -1, "");
         edge = e;
      }
   }

   MFEM_ASSERT(edge >= 0, "");

   // Edges 0-7 alternate between axes 0 and 1; edges 8-11 are along axis 2.
   static constexpr int axisOfEdge[12] = {0, 1, 0, 1, 0, 1, 0, 1, 2, 2, 2, 2};
   return axisOfEdge[edge];
}

// Checks the face (vn1, vn2, vn3, vn4) of element 'elem' against the locally
// requested refinement of the element on the other side of that face.
// Nothing is done if the face does not exist, has no neighbor, or the
// neighbor has no entry in 'elemToRef' (element index -> position in
// 'refinements'). If the neighbor's refinement splits this face along exactly
// one tangential axis, and that axis differs from the hex axis of edge
// (vn1, vn2), the refinement's index is inserted into 'conflicts'.
void ParNCMesh::CheckRefAnisoFace(int elem, int vn1, int vn2, int vn3, int vn4,
                                  const Array<Refinement> &refinements,
                                  const std::map<int, int> &elemToRef,
                                  std::set<int> &conflicts)
{
   Face* face = faces.Find(vn1, vn2, vn3, vn4);
   if (!face) { return; }

   // Find the neighbor of this face.
   const int nghbIndex = face->elem[0] == elem ? face->elem[1] : face->elem[0];
   if (nghbIndex < 0) { return; }

   Element &nghb = elements[nghbIndex];
   MFEM_ASSERT(nghb.ref_type == 0, "");

   // If the neighbor is not marked for refinement, there is nothing to check
   // (one map lookup instead of count() followed by at()).
   const auto it = elemToRef.find(nghbIndex);
   if (it == elemToRef.end()) { return; }

   const int refIndex = it->second;
   const Refinement& ref = refinements[refIndex];

   // ref.s[i] > 0 is treated as "refined along hex axis i".
   bool refDir[3];
   for (int i = 0; i < 3; ++i)
   {
      refDir[i] = ref.s[i] > real_t{0};
   }

   const int localFace = FindHexFace(nghb.node, vn1, vn2, vn3, vn4);
   const int faceDir = GetHexFaceDir(localFace);
   const char face_ref_type = GetHexFaceRefType(refDir, localFace);
   const bool faceAniso = face_ref_type == 1 ||
                          face_ref_type == 2; // X or Y w.r.t. the face.

   if (!faceAniso) { return; }

   // Determine whether the face is anisotropically split in the vertical
   // direction, with respect to the vertex ordering (vn1, vn2, vn3, vn4).
   // hexSplitOnFace is the hex axis along which the neighbor's refinement cuts
   // this face: the lower-numbered tangential axis for face_ref_type 1, the
   // higher-numbered one for face_ref_type 2 (the bit layout produced by
   // GetHexFaceRefType).
   int hexSplitOnFace = -1;

   const int firstFaceDir = face_ref_type == 1 ? 0 : 1;

   int cnt = 0;
   for (int i = 0; i < 3; ++i)
   {
      if (i == faceDir) { continue; }

      if (firstFaceDir == cnt)
      {
         MFEM_ASSERT(hexSplitOnFace == -1, "");
         hexSplitOnFace = i;
      }

      cnt++;
   }
   MFEM_ASSERT(cnt == 2 && hexSplitOnFace >= 0, "");

   // Conflict unless the neighbor's split axis matches the axis of the edge
   // (vn1, vn2) in the neighbor hex.
   const int edgeSplit = GetHexEdgeSplit(nghb.node, vn1, vn2);
   if (edgeSplit != hexSplitOnFace) { conflicts.insert(refIndex); }
}

void ParNCMesh::CheckRefIsoFace(int elem, int vn1, int vn2, int vn3, int vn4,
                                int en1, int en2, int en3, int en4,
                                const Array<Refinement> &refinements,
                                const std::map<int, int> &elemToRef,
                                std::set<int> &conflicts)
{
   // Runs CheckRefAnisoFace, in this order, on the four node quadruples built
   // from the face nodes vn1..vn4 and the extra nodes en1..en4 (presumably the
   // mid-edge nodes of the face — confirm with callers).
   const int quads[4][4] =
   {
      { vn1, vn2, en2, en4 },
      { en4, en2, vn3, vn4 },
      { vn4, vn1, en1, en3 },
      { en3, en1, vn2, vn3 }
   };

   for (const auto &q : quads)
   {
      CheckRefAnisoFace(elem, q[0], q[1], q[2], q[3], refinements, elemToRef,
                        conflicts);
   }
}

// Check whether refining the unrefined hexahedron 'elem' with 'ref_type'
// conflicts with refinements of neighboring elements. Every face of 'elem'
// that the refinement splits is checked:
//  - faces split in one direction go through CheckRefAnisoFace();
//  - faces split in both directions go through CheckRefIsoFace().
// Conflicts found by those checks are inserted into 'conflicts'.
// 'refinements' and 'elemToRef' are only forwarded to the face checks.
// NOTE(review): GetMidEdgeNode() is defined outside this chunk; confirm
// whether it can create mid-edge nodes as a side effect of this check.
void ParNCMesh::CheckRefinement(int elem, char ref_type,
                                const Array<Refinement> &refinements,
                                const std::map<int, int> &elemToRef,
                                std::set<int> &conflicts)
{
   const Element &el = elements[elem];
   MFEM_ASSERT(el.geom == Geometry::CUBE && el.ref_type == 0,
               "Element must be an unrefined hexahedron");

   const int* no = el.node;

   // Check the faces of this element being refined (depends on ref_type).
   // This follows the logic of NCMesh::RefineElement().
   if (ref_type == Refinement::X) // split along X axis
   {
      // the four faces cut by the splitting plane are split anisotropically
      CheckRefAnisoFace(elem, no[0], no[1], no[5], no[4], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[2], no[3], no[7], no[6], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[4], no[5], no[6], no[7], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[3], no[2], no[1], no[0], refinements,
                        elemToRef, conflicts);
   }
   else if (ref_type == Refinement::Y) // split along Y axis
   {
      CheckRefAnisoFace(elem, no[1], no[2], no[6], no[5], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[3], no[0], no[4], no[7], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[5], no[6], no[7], no[4], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[0], no[3], no[2], no[1], refinements,
                        elemToRef, conflicts);
   }
   else if (ref_type == Refinement::Z) // split along Z axis
   {
      CheckRefAnisoFace(elem, no[4], no[0], no[1], no[5], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[5], no[1], no[2], no[6], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[6], no[2], no[3], no[7], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[7], no[3], no[0], no[4], refinements,
                        elemToRef, conflicts);
   }
   else if (ref_type == Refinement::XY) // XY split
   {
      // four side faces are split in one direction...
      CheckRefAnisoFace(elem, no[0], no[1], no[5], no[4], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[1], no[2], no[6], no[5], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[2], no[3], no[7], no[6], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[3], no[0], no[4], no[7], refinements,
                        elemToRef, conflicts);

      // ...while faces (0,1,2,3) and (4,5,6,7) are split in both directions,
      // so their mid-edge nodes are needed
      const int mid01 = GetMidEdgeNode(no[0], no[1]);
      const int mid12 = GetMidEdgeNode(no[1], no[2]);
      const int mid23 = GetMidEdgeNode(no[2], no[3]);
      const int mid30 = GetMidEdgeNode(no[3], no[0]);

      const int mid45 = GetMidEdgeNode(no[4], no[5]);
      const int mid56 = GetMidEdgeNode(no[5], no[6]);
      const int mid67 = GetMidEdgeNode(no[6], no[7]);
      const int mid74 = GetMidEdgeNode(no[7], no[4]);

      CheckRefIsoFace(elem, no[3], no[2], no[1], no[0], mid23, mid12, mid01,
                      mid30, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[4], no[5], no[6], no[7], mid45, mid56, mid67,
                      mid74, refinements, elemToRef, conflicts);
   }
   else if (ref_type == Refinement::XZ) // XZ split
   {
      CheckRefAnisoFace(elem, no[3], no[2], no[1], no[0], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[2], no[6], no[5], no[1], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[6], no[7], no[4], no[5], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[7], no[3], no[0], no[4], refinements,
                        elemToRef, conflicts);

      // faces (0,1,5,4) and (2,3,7,6) are split in both directions
      const int mid01 = GetMidEdgeNode(no[0], no[1]);
      const int mid23 = GetMidEdgeNode(no[2], no[3]);
      const int mid45 = GetMidEdgeNode(no[4], no[5]);
      const int mid67 = GetMidEdgeNode(no[6], no[7]);

      const int mid04 = GetMidEdgeNode(no[0], no[4]);
      const int mid15 = GetMidEdgeNode(no[1], no[5]);
      const int mid26 = GetMidEdgeNode(no[2], no[6]);
      const int mid37 = GetMidEdgeNode(no[3], no[7]);

      CheckRefIsoFace(elem, no[0], no[1], no[5], no[4], mid01, mid15, mid45,
                      mid04, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[2], no[3], no[7], no[6], mid23, mid37, mid67,
                      mid26, refinements, elemToRef, conflicts);
   }
   else if (ref_type == Refinement::YZ) // YZ split
   {
      // mid-edge nodes of faces (1,2,6,5) and (3,0,4,7), which are split in
      // both directions (checked last, below)
      const int mid12 = GetMidEdgeNode(no[1], no[2]);
      const int mid30 = GetMidEdgeNode(no[3], no[0]);
      const int mid56 = GetMidEdgeNode(no[5], no[6]);
      const int mid74 = GetMidEdgeNode(no[7], no[4]);

      const int mid04 = GetMidEdgeNode(no[0], no[4]);
      const int mid15 = GetMidEdgeNode(no[1], no[5]);
      const int mid26 = GetMidEdgeNode(no[2], no[6]);
      const int mid37 = GetMidEdgeNode(no[3], no[7]);

      CheckRefAnisoFace(elem, no[4], no[0], no[1], no[5], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[0], no[3], no[2], no[1], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[3], no[7], no[6], no[2], refinements,
                        elemToRef, conflicts);
      CheckRefAnisoFace(elem, no[7], no[4], no[5], no[6], refinements,
                        elemToRef, conflicts);

      CheckRefIsoFace(elem, no[1], no[2], no[6], no[5], mid12, mid26, mid56,
                      mid15, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[3], no[0], no[4], no[7], mid30, mid04, mid74,
                      mid37, refinements, elemToRef, conflicts);
   }
   else if (ref_type == Refinement::XYZ) // XYZ split
   {
      // all twelve edges are bisected and all six faces are split in both
      // directions
      const int mid01 = GetMidEdgeNode(no[0], no[1]);
      const int mid12 = GetMidEdgeNode(no[1], no[2]);
      const int mid23 = GetMidEdgeNode(no[2], no[3]);
      const int mid30 = GetMidEdgeNode(no[3], no[0]);

      const int mid45 = GetMidEdgeNode(no[4], no[5]);
      const int mid56 = GetMidEdgeNode(no[5], no[6]);
      const int mid67 = GetMidEdgeNode(no[6], no[7]);
      const int mid74 = GetMidEdgeNode(no[7], no[4]);

      const int mid04 = GetMidEdgeNode(no[0], no[4]);
      const int mid15 = GetMidEdgeNode(no[1], no[5]);
      const int mid26 = GetMidEdgeNode(no[2], no[6]);
      const int mid37 = GetMidEdgeNode(no[3], no[7]);

      CheckRefIsoFace(elem, no[3], no[2], no[1], no[0], mid23, mid12, mid01,
                      mid30, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[0], no[1], no[5], no[4], mid01, mid15, mid45,
                      mid04, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[1], no[2], no[6], no[5], mid12, mid26, mid56,
                      mid15, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[2], no[3], no[7], no[6], mid23, mid37, mid67,
                      mid26, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[3], no[0], no[4], no[7], mid30, mid04, mid74,
                      mid37, refinements, elemToRef, conflicts);
      CheckRefIsoFace(elem, no[4], no[5], no[6], no[7], mid45, mid56, mid67,
                      mid74, refinements, elemToRef, conflicts);
   }
   else
   {
      MFEM_ABORT("Invalid refinement type.");
   }
}

// Refine the local leaf elements listed in 'refinements' (Refinement::index
// is an index into leaf_elements) and keep the ghost layer consistent across
// ranks.
//  - Refinements of elements adjacent to the processor boundary are sent to
//    the neighboring ranks.
//  - Refinements received from the neighbors are applied to our ghost
//    elements.
// With a single rank this simply defers to NCMesh::Refine().
void ParNCMesh::Refine(const Array<Refinement> &refinements)
{
   if (NRanks == 1)
   {
      NCMesh::Refine(refinements);
      return;
   }

   // the mesh is no longer isotropic once any requested local refinement is
   // not XYZ (the loop stops as soon as Iso becomes false)
   for (int i = 0; i < refinements.Size() && Iso; i++)
   {
      const Refinement &ref = refinements[i];
      if (ref.GetType() != Refinement::XYZ)
      {
         Iso = false;
      }
   }

   NeighborRefinementMessage::Map send_ref;

   // create refinement messages to all neighbors (NOTE: some may be empty)
   Array<int> neighbors;
   NeighborProcessors(neighbors);
   for (int i = 0; i < neighbors.Size(); i++)
   {
      send_ref[neighbors[i]].SetNCMesh(this);
   }

   // populate messages: all refinements that occur next to the processor
   // boundary need to be sent to the adjoining neighbors so they can keep
   // their ghost layer up to date
   Array<int> ranks;
   ranks.Reserve(64);
   for (int i = 0; i < refinements.Size(); i++)
   {
      const Refinement &ref = refinements[i];
      MFEM_ASSERT(ref.index < NElements, "");
      const int elem = leaf_elements[ref.index];
      ElementNeighborProcessors(elem, ranks);
      for (int j = 0; j < ranks.Size(); j++)
      {
         send_ref[ranks[j]].AddRefinement(elem, ref.GetType());
      }
   }

   // send the messages (overlap with local refinements)
   NeighborRefinementMessage::IsendAll(send_ref, MyComm);

   // do local refinements
   for (int i = 0; i < refinements.Size(); i++)
   {
      const Refinement &ref = refinements[i];
      NCMesh::RefineElement(leaf_elements[ref.index], ref.GetType());
   }

   // receive (ghost layer) refinements from all neighbors; exactly one
   // (possibly empty) message is expected from each neighbor, so probe for
   // any source 'neighbors.Size()' times
   for (int j = 0; j < neighbors.Size(); j++)
   {
      int rank, size;
      NeighborRefinementMessage::Probe(rank, size, MyComm);

      NeighborRefinementMessage msg;
      msg.SetNCMesh(this);
      msg.Recv(rank, size, MyComm);

      // do the ghost refinements
      for (int i = 0; i < msg.Size(); i++)
      {
         NCMesh::RefineElement(msg.elements[i], msg.values[i]);
      }
   }

   Update();

   // make sure we can delete the send buffers
   NeighborRefinementMessage::WaitAllSent(send_ref);
}


void ParNCMesh::LimitNCLevel(int max_nc_level)
{
   MFEM_VERIFY(max_nc_level >= 1, "'max_nc_level' must be 1 or greater.");

   while (1)
   {
      Array<Refinement> refinements;
      GetLimitRefinements(refinements, max_nc_level);

      long long size = refinements.Size(), glob_size;
      MPI_Allreduce(&size, &glob_size, 1, MPI_LONG_LONG, MPI_SUM, MyComm);

      if (!glob_size) { break; }

      Refine(refinements);
   }
}

// For each element owned by this rank, compute the rank that would own it
// after the derefinements 'derefs' (rows of the 'derefinements' table) are
// performed. All fine elements of a derefinement go to the lowest rank among
// their current owners. Elements not involved keep their current rank.
// On output, 'new_ranks' has one entry per owned (non-ghost) leaf element.
void ParNCMesh::GetFineToCoarsePartitioning(const Array<int> &derefs,
                                            Array<int> &new_ranks) const
{
   const int num_owned = leaf_elements.Size() - GetNGhostElements();

   new_ranks.SetSize(num_owned);
   for (int i = 0; i < num_owned; i++)
   {
      new_ranks[i] = elements[leaf_elements[i]].rank;
   }

   for (int i = 0; i < derefs.Size(); i++)
   {
      const int row = derefs[i];
      MFEM_VERIFY(row >= 0 && row < derefinements.Size(),
                  "invalid derefinement number.");

      const int* fine = derefinements.GetRow(row);
      const int size = derefinements.RowSize(row);

      int coarse_rank = INT_MAX;
      for (int j = 0; j < size; j++)
      {
         const int fine_rank = elements[leaf_elements[fine[j]]].rank;
         coarse_rank = std::min(coarse_rank, fine_rank);
      }

      // A derefinement straddling a processor boundary also lists ghost
      // children, whose leaf indices lie past the owned elements and have no
      // slot in 'new_ranks'; writing them would be out of bounds.
      for (int j = 0; j < size; j++)
      {
         if (fine[j] < num_owned) { new_ranks[fine[j]] = coarse_rank; }
      }
   }
}

// Parallel derefinement. 'derefs' lists rows of the 'derefinements' table
// whose fine elements are to be merged back into their parent.
//  - Step 1 moves all fine elements of each selected row to the lowest of
//    their owner ranks, so that no derefinement straddles a processor
//    boundary.
//  - Step 2 performs the derefinements locally and sends those adjacent to
//    the processor boundary to the neighbors, so their ghost layers stay in
//    sync.
// Afterwards, transforms.embeddings[i].parent holds, for each pre-
// derefinement leaf i, the index of its new element. If that element is
// pruned away, the entry holds -1 - (owner rank).
// Only 2D meshes and isotropically refined 3D meshes are supported.
void ParNCMesh::Derefine(const Array<int> &derefs)
{
   MFEM_VERIFY(Dim < 3 || Iso,
               "derefinement of 3D anisotropic meshes not implemented yet.");

   InitDerefTransforms();

   // store fine element ranks
   old_index_or_rank.SetSize(leaf_elements.Size());
   for (int i = 0; i < leaf_elements.Size(); i++)
   {
      old_index_or_rank[i] = elements[leaf_elements[i]].rank;
   }

   // back up the leaf_elements array
   Array<int> old_elements;
   leaf_elements.Copy(old_elements);

   // *** STEP 1: redistribute elements to avoid complex derefinements ***

   Array<int> new_ranks(leaf_elements.Size());
   for (int i = 0; i < leaf_elements.Size(); i++)
   {
      new_ranks[i] = elements[leaf_elements[i]].rank;
   }

   // make the lowest rank get all the fine elements for each derefinement
   for (int i = 0; i < derefs.Size(); i++)
   {
      int row = derefs[i];
      MFEM_VERIFY(row >= 0 && row < derefinements.Size(),
                  "invalid derefinement number.");

      const int* fine = derefinements.GetRow(row);
      int size = derefinements.RowSize(row);

      int coarse_rank = INT_MAX;
      for (int j = 0; j < size; j++)
      {
         int fine_rank = elements[leaf_elements[fine[j]]].rank;
         coarse_rank = std::min(coarse_rank, fine_rank);
      }
      for (int j = 0; j < size; j++)
      {
         new_ranks[fine[j]] = coarse_rank;
      }
   }

   // number of elements this rank will own after the redistribution
   int target_elements = 0;
   for (int i = 0; i < new_ranks.Size(); i++)
   {
      if (new_ranks[i] == MyRank) { target_elements++; }
   }

   // redistribute elements slightly to get rid of complex derefinements
   // straddling processor boundaries *and* update the ghost layer
   RedistributeElements(new_ranks, target_elements, false);

   // *** STEP 2: derefine now, communication similar to Refine() ***

   NeighborDerefinementMessage::Map send_deref;

   // create derefinement messages to all neighbors (NOTE: some may be empty)
   Array<int> neighbors;
   NeighborProcessors(neighbors);
   for (int i = 0; i < neighbors.Size(); i++)
   {
      send_deref[neighbors[i]].SetNCMesh(this);
   }

   // derefinements that occur next to the processor boundary need to be sent
   // to the adjoining neighbors to keep their ghost layers in sync
   Array<int> ranks;
   ranks.Reserve(64);
   for (int i = 0; i < derefs.Size(); i++)
   {
      // the parent is reached through the pre-redistribution leaf array,
      // because 'fine' holds old leaf indices
      const int* fine = derefinements.GetRow(derefs[i]);
      int parent = elements[old_elements[fine[0]]].parent;

      // send derefinement to neighbors
      ElementNeighborProcessors(parent, ranks);
      for (int j = 0; j < ranks.Size(); j++)
      {
         send_deref[ranks[j]].AddDerefinement(parent, new_ranks[fine[0]]);
      }
   }
   NeighborDerefinementMessage::IsendAll(send_deref, MyComm);

   // restore old (pre-redistribution) element indices, for SetDerefMatrixCodes
   for (int i = 0; i < leaf_elements.Size(); i++)
   {
      elements[leaf_elements[i]].index = -1;
   }
   for (int i = 0; i < old_elements.Size(); i++)
   {
      elements[old_elements[i]].index = i;
   }

   // do local derefinements
   // NOTE(review): 'coarse' starts as a copy of the old leaf elements and is
   // passed to SetDerefMatrixCodes(); presumably it is updated there to map
   // each old leaf to its coarse element -- confirm.
   Array<int> coarse;
   old_elements.Copy(coarse);
   for (int i = 0; i < derefs.Size(); i++)
   {
      const int* fine = derefinements.GetRow(derefs[i]);
      int parent = elements[old_elements[fine[0]]].parent;

      // record the relation of the fine elements to their parent
      SetDerefMatrixCodes(parent, coarse);

      NCMesh::DerefineElement(parent);
   }

   // receive ghost layer derefinements from all neighbors
   for (int j = 0; j < neighbors.Size(); j++)
   {
      int rank, size;
      NeighborDerefinementMessage::Probe(rank, size, MyComm);

      NeighborDerefinementMessage msg;
      msg.SetNCMesh(this);
      msg.Recv(rank, size, MyComm);

      // do the ghost derefinements
      for (int i = 0; i < msg.Size(); i++)
      {
         int elem = msg.elements[i];
         // only derefine if the element is still refined here
         if (elements[elem].ref_type)
         {
            SetDerefMatrixCodes(elem, coarse);
            NCMesh::DerefineElement(elem);
         }
         elements[elem].rank = msg.values[i];
      }
   }

   // update leaf_elements, Element::index etc.
   Update();

   UpdateLayers();

   // link old fine elements to the new coarse elements
   for (int i = 0; i < coarse.Size(); i++)
   {
      int index = elements[coarse[i]].index;
      if (element_type[index] == 0)
      {
         // this coarse element will get pruned, encode who owns it now
         index = -1 - elements[coarse[i]].rank;
      }
      transforms.embeddings[i].parent = index;
   }

   // keep the post-derefinement leaf array, to translate indices after Prune
   leaf_elements.Copy(old_elements);

   Prune();

   // renumber coarse element indices after pruning
   for (int i = 0; i < coarse.Size(); i++)
   {
      int &index = transforms.embeddings[i].parent;
      if (index >= 0)
      {
         index = elements[old_elements[index]].index;
      }
   }

   // make sure we can delete all send buffers
   NeighborDerefinementMessage::WaitAllSent(send_deref);
}


// Make 'elem_data' consistent for derefinements that straddle processor
// boundaries. For every row of 'deref_table' whose fine elements have more
// than one owner:
//  - the value of each element owned by this rank is sent to every other
//    owner in the row;
//  - the values of the other owners' elements are received into the
//    corresponding (ghost) entries.
// 'elem_data' is first resized to leaf_elements.Size(), so ghost entries
// exist; new entries are zero. The exchange is non-blocking, and all
// requests are waited on before returning.
template<typename Type>
void ParNCMesh::SynchronizeDerefinementData(Array<Type> &elem_data,
                                            const Table &deref_table)
{
   const MPI_Datatype datatype = MPITypeMap<Type>::mpi_type;

   // MPI_Request is a copyable handle, so requests are stored by value
   // (no per-request heap allocation)
   Array<MPI_Request> requests;
   Array<int> neigh;

   requests.Reserve(64);
   neigh.Reserve(8);

   // make room for ghost values (indices beyond NumElements); elem_data is
   // not resized again below, so the buffer pointers passed to MPI stay valid
   elem_data.SetSize(leaf_elements.Size(), 0);

   for (int i = 0; i < deref_table.Size(); i++)
   {
      const int* fine = deref_table.GetRow(i);
      const int size = deref_table.RowSize(i);
      MFEM_ASSERT(size <= 8, "");

      int ranks[8], min_rank = INT_MAX, max_rank = INT_MIN;
      for (int j = 0; j < size; j++)
      {
         ranks[j] = elements[leaf_elements[fine[j]]].rank;
         min_rank = std::min(min_rank, ranks[j]);
         max_rank = std::max(max_rank, ranks[j]);
      }

      // only derefinements straddling processor boundaries need an exchange
      if (min_rank == max_rank) { continue; }

      // the other ranks owning children in this row
      neigh.SetSize(0);
      for (int j = 0; j < size; j++)
      {
         if (ranks[j] != MyRank) { neigh.Append(ranks[j]); }
      }
      neigh.Sort();
      neigh.Unique();

      for (int j = 0; j < size; j++)
      {
         Type *data = &elem_data[fine[j]];

         if (ranks[j] == MyRank)
         {
            // our value goes to all other owners in this row
            for (int k = 0; k < neigh.Size(); k++)
            {
               MPI_Request req;
               MPI_Isend(data, 1, datatype, neigh[k], 292, MyComm, &req);
               requests.Append(req);
            }
         }
         else
         {
            MPI_Request req;
            MPI_Irecv(data, 1, datatype, ranks[j], 292, MyComm, &req);
            requests.Append(req);
         }
      }
   }

   for (int i = 0; i < requests.Size(); i++)
   {
      MPI_Wait(&requests[i], MPI_STATUS_IGNORE);
   }
}

// Explicitly instantiate SynchronizeDerefinementData for int, double, and
// float, so the template definition can stay in this source file.
template void
ParNCMesh::SynchronizeDerefinementData<int>(Array<int> &, const Table &);
template void
ParNCMesh::SynchronizeDerefinementData<double>(Array<double> &, const Table &);
template void
ParNCMesh::SynchronizeDerefinementData<float>(Array<float> &, const Table &);


void ParNCMesh::CheckDerefinementNCLevel(const Table &deref_table,
                                         Array<int> &level_ok, int max_nc_level)
{
   Array<int> leaf_ok(leaf_elements.Size());
   leaf_ok = 1;

   // check elements that we own
   for (int i = 0; i < deref_table.Size(); i++)
   {
      const int *fine = deref_table.GetRow(i),
                 size = deref_table.RowSize(i);

      int parent = elements[leaf_elements[fine[0]]].parent;
      Element &pa = elements[parent];

      for (int j = 0; j < size; j++)
      {
         int child = leaf_elements[fine[j]];
         if (elements[child].rank == MyRank)
         {
            int splits[3];
            CountSplits(child, splits);

            for (int k = 0; k < Dim; k++)
            {
               if ((pa.ref_type & (1 << k)) &&
                   splits[k] >= max_nc_level)
               {
                  leaf_ok[fine[j]] = 0; break;
               }
            }
         }
      }
   }

   SynchronizeDerefinementData(leaf_ok, deref_table);

   level_ok.SetSize(deref_table.Size());
   level_ok = 1;

   for (int i = 0; i < deref_table.Size(); i++)
   {
      const int* fine = deref_table.GetRow(i),
                 size = deref_table.RowSize(i);

      for (int j = 0; j < size; j++)
      {
         if (!leaf_ok[fine[j]])
         {
            level_ok[i] = 0; break;
         }
      }
   }
}


//// Rebalance /////////////////////////////////////////////////////////////////

void ParNCMesh::Rebalance(const Array<int> *custom_partition)
{
   send_rebalance_dofs.clear();
   recv_rebalance_dofs.clear();

   Array<int> old_elements;
   leaf_elements.GetSubArray(0, NElements, old_elements);

   if (!custom_partition) // SFC based partitioning
   {
      Array<int> new_ranks(leaf_elements.Size());
      new_ranks = -1;

      // figure out new assignments for Element::rank
      long local_elems = NElements, total_elems = 0;
      MPI_Allreduce(&local_elems, &total_elems, 1, MPI_LONG, MPI_SUM, MyComm);

      long first_elem_global = 0;
      MPI_Scan(&local_elems, &first_elem_global, 1, MPI_LONG, MPI_SUM, MyComm);
      first_elem_global -= local_elems;

      for (int i = 0, j = 0; i < leaf_elements.Size(); i++)
      {
         if (elements[leaf_elements[i]].rank == MyRank)
         {
            new_ranks[i] = Partition(first_elem_global + (j++), total_elems);
         }
      }

      int target_elements = PartitionFirstIndex(MyRank+1, total_elems)
                            - PartitionFirstIndex(MyRank, total_elems);

      // assign the new ranks and send elements (plus ghosts) to new owners
      RedistributeElements(new_ranks, target_elements, true);
   }
   else // whatever partitioning the user has passed
   {
      MFEM_VERIFY(custom_partition->Size() == NElements,
                  "Size of the partition array must match the number "
                  "of local mesh elements (ParMesh::GetNE()).");

      Array<int> new_ranks;
      custom_partition->Copy(new_ranks);
      new_ranks.SetSize(leaf_elements.Size(), -1); // make room for ghosts

      RedistributeElements(new_ranks, -1, true);
   }

   // set up the old index array
   old_index_or_rank.SetSize(NElements);
   old_index_or_rank = -1;
   for (int i = 0; i < old_elements.Size(); i++)
   {
      Element &el = elements[old_elements[i]];
      if (el.rank == MyRank) { old_index_or_rank[el.index] = i; }
   }

   // get rid of elements beyond the new ghost layer
   Prune();
}

// Move elements to the ranks given in 'new_ranks', which is indexed by leaf
// index. On entry it must hold the new owners of our own elements; the
// entries for our ghost layer are filled in by Step 1 below.
//  - 'target_elements' >= 0 selects SFC mode: receiving stops once this rank
//    holds 'target_elements' elements.
//  - A negative 'target_elements' (custom partitioning) uses a non-blocking
//    consensus based on MPI_Ibarrier to detect the end of the exchange.
//  - If 'record_comm' is true, the ranks sent to and received from are
//    recorded in send_rebalance_dofs/recv_rebalance_dofs for later use by
//    SendRebalanceDofs() and RecvRebalanceDofs().
void ParNCMesh::RedistributeElements(Array<int> &new_ranks, int target_elements,
                                     bool record_comm)
{
   bool sfc = (target_elements >= 0);

   UpdateLayers();

   // *** STEP 1: communicate new rank assignments for the ghost layer ***

   NeighborElementRankMessage::Map send_ghost_ranks, recv_ghost_ranks;

   // group ghost elements by their current owner
   ghost_layer.Sort([&](const int a, const int b)
   {
      return elements[a].rank < elements[b].rank;
   });

   {
      Array<int> rank_neighbors;

      // loop over neighbor ranks and their elements
      int begin = 0, end = 0;
      while (end < ghost_layer.Size())
      {
         // find range of elements belonging to one rank
         int rank = elements[ghost_layer[begin]].rank;
         while (end < ghost_layer.Size() &&
                elements[ghost_layer[end]].rank == rank) { end++; }

         Array<int> rank_elems;
         rank_elems.MakeRef(&ghost_layer[begin], end - begin);

         // find elements within boundary_layer that are neighbors to 'rank'
         rank_neighbors.SetSize(0);
         NeighborExpand(rank_elems, rank_neighbors, &boundary_layer);

         // send a message with new rank assignments within 'rank_neighbors'
         NeighborElementRankMessage& msg = send_ghost_ranks[rank];
         msg.SetNCMesh(this);

         msg.Reserve(rank_neighbors.Size());
         for (int i = 0; i < rank_neighbors.Size(); i++)
         {
            int elem = rank_neighbors[i];
            msg.AddElementRank(elem, new_ranks[elements[elem].index]);
         }

         msg.Isend(rank, MyComm);

         // prepare to receive a message from the neighbor too, these will
         // be the new rank assignments for our ghost layer
         recv_ghost_ranks[rank].SetNCMesh(this);

         begin = end;
      }
   }

   NeighborElementRankMessage::RecvAll(recv_ghost_ranks, MyComm);

   // read new ranks for the ghost layer from messages received
   for (auto &kv : recv_ghost_ranks)
   {
      NeighborElementRankMessage &msg = kv.second;
      for (int i = 0; i < msg.Size(); i++)
      {
         int ghost_index = elements[msg.elements[i]].index;
         MFEM_ASSERT(element_type[ghost_index] == 2, "");
         new_ranks[ghost_index] = msg.values[i];
      }
   }

   recv_ghost_ranks.clear();

   // *** STEP 2: send elements that no longer belong to us to new assignees ***

   /* The result thus far is just the array 'new_ranks' containing new owners
      for elements that we currently own plus new owners for the ghost layer.
      Next we keep elements that still belong to us and send ElementSets with
      the remaining elements to their new owners. Each batch of elements needs
      to be sent together with their neighbors so the receiver also gets a
      ghost layer that is up to date (this is why we needed Step 1). */

   int received_elements = 0;
   for (int i = 0; i < leaf_elements.Size(); i++)
   {
      Element &el = elements[leaf_elements[i]];
      if (el.rank == MyRank && new_ranks[i] == MyRank)
      {
         received_elements++; // initialize to number of elements we're keeping
      }
      el.rank = new_ranks[i];
   }

   int nsent = 0, nrecv = 0; // for debug check

   RebalanceMessage::Map send_elems;
   {
      // sort elements we own by the new rank
      Array<int> owned_elements;
      owned_elements.MakeRef(leaf_elements.GetData(), NElements);
      owned_elements.Sort([&](const int a, const int b)
      {
         return elements[a].rank < elements[b].rank;
      });

      Array<int> batch;
      batch.Reserve(1024);

      // send elements to new owners
      int begin = 0, end = 0;
      while (end < NElements)
      {
         // find range of elements belonging to one rank
         int rank = elements[owned_elements[begin]].rank;
         while (end < owned_elements.Size() &&
                elements[owned_elements[end]].rank == rank) { end++; }

         if (rank != MyRank)
         {
            Array<int> rank_elems;
            rank_elems.MakeRef(&owned_elements[begin], end - begin);

            // expand the 'rank_elems' set by its neighbor elements (ghosts)
            batch.SetSize(0);
            NeighborExpand(rank_elems, batch);

            // send the batch
            RebalanceMessage &msg = send_elems[rank];
            msg.SetNCMesh(this);

            msg.Reserve(batch.Size());
            for (int i = 0; i < batch.Size(); i++)
            {
               int elem = batch[i];
               Element &el = elements[elem];

               if ((element_type[el.index] & 1) || el.rank != rank)
               {
                  msg.AddElementRank(elem, el.rank);
               }
               // NOTE: we skip 'ghosts' that are of the receiver's rank because
               // they are not really ghosts and would get sent multiple times,
               // disrupting the termination mechanism in Step 4.
            }

            if (sfc)
            {
               msg.Isend(rank, MyComm);
            }
            else
            {
               // custom partitioning needs synchronous sends
               msg.Issend(rank, MyComm);
            }
            nsent++;

            // also: record what elements we sent (excluding the ghosts)
            // so that SendRebalanceDofs can later send data for them
            if (record_comm)
            {
               send_rebalance_dofs[rank].SetElements(rank_elems, this);
            }
         }

         begin = end;
      }
   }

   // *** STEP 3: receive elements from others ***

   RebalanceMessage msg;
   msg.SetNCMesh(this);

   if (sfc)
   {
      /* We don't know from whom we're going to receive, so we need to probe.
         However, for the default SFC partitioning, we do know how many elements
         we're going to own eventually, so the termination condition is easy. */

      while (received_elements < target_elements)
      {
         int rank, size;
         RebalanceMessage::Probe(rank, size, MyComm);

         // receive message; note: elements are created as the message is decoded
         msg.Recv(rank, size, MyComm);
         nrecv++;

         for (int i = 0; i < msg.Size(); i++)
         {
            int elem_rank = msg.values[i];
            elements[msg.elements[i]].rank = elem_rank;

            if (elem_rank == MyRank) { received_elements++; }
         }

         // save the ranks we received from, for later use in RecvRebalanceDofs
         if (record_comm)
         {
            recv_rebalance_dofs[rank].SetNCMesh(this);
         }
      }

      Update();

      RebalanceMessage::WaitAllSent(send_elems);
   }
   else
   {
      /* The case (target_elements < 0) is used for custom partitioning.
         Here we need to employ the "non-blocking consensus" algorithm
         (https://scorec.rpi.edu/REPORTS/2015-9.pdf) to determine when the
         element exchange is finished. The algorithm uses a non-blocking
         barrier. */

      MPI_Request barrier = MPI_REQUEST_NULL;
      int done = 0;

      while (!done)
      {
         int rank, size;
         while (RebalanceMessage::IProbe(rank, size, MyComm))
         {
            // receive message; note: elements are created as the msg is decoded
            msg.Recv(rank, size, MyComm);
            nrecv++;

            for (int i = 0; i < msg.Size(); i++)
            {
               elements[msg.elements[i]].rank = msg.values[i];
            }

            // save the ranks we received from, for later use in RecvRebalanceDofs
            if (record_comm)
            {
               recv_rebalance_dofs[rank].SetNCMesh(this);
            }
         }

         // once all our synchronous sends have completed (i.e., were matched
         // by receives), enter the non-blocking barrier; the exchange is
         // finished when the barrier completes
         if (barrier != MPI_REQUEST_NULL)
         {
            MPI_Test(&barrier, &done, MPI_STATUS_IGNORE);
         }
         else
         {
            if (RebalanceMessage::TestAllSent(send_elems))
            {
               int mpi_err = MPI_Ibarrier(MyComm, &barrier);

               MFEM_VERIFY(mpi_err == MPI_SUCCESS, "");
               MFEM_VERIFY(barrier != MPI_REQUEST_NULL, "");
            }
         }
      }

      Update();
   }

   NeighborElementRankMessage::WaitAllSent(send_ghost_ranks);

#ifdef MFEM_DEBUG
   // every element message sent must have been received somewhere
   int glob_sent, glob_recv;
   MPI_Reduce(&nsent, &glob_sent, 1, MPI_INT, MPI_SUM, 0, MyComm);
   MPI_Reduce(&nrecv, &glob_recv, 1, MPI_INT, MPI_SUM, 0, MyComm);

   if (MyRank == 0)
   {
      MFEM_ASSERT(glob_sent == glob_recv,
                  "(glob_sent, glob_recv) = ("
                  << glob_sent << ", " << glob_recv << ")");
   }
#else
   MFEM_CONTRACT_VAR(nsent);
   MFEM_CONTRACT_VAR(nrecv);
#endif
}


void ParNCMesh::SendRebalanceDofs(int old_ndofs,
                                  const Table &old_element_dofs,
                                  long old_global_offset,
                                  FiniteElementSpace *space)
{
   Array<int> dofs;
   int vdim = space->GetVDim();

   // fill messages (prepared by Rebalance) with element DOFs
   RebalanceDofMessage::Map::iterator it;
   for (it = send_rebalance_dofs.begin(); it != send_rebalance_dofs.end(); ++it)
   {
      RebalanceDofMessage &msg = it->second;
      msg.dofs.clear();
      int ne = static_cast<int>(msg.elem_ids.size());
      if (ne)
      {
         msg.dofs.reserve(old_element_dofs.RowSize(msg.elem_ids[0]) * ne * vdim);
      }
      for (int i = 0; i < ne; i++)
      {
         old_element_dofs.GetRow(msg.elem_ids[i], dofs);
         space->DofsToVDofs(dofs, old_ndofs);
         msg.dofs.insert(msg.dofs.end(), dofs.begin(), dofs.end());
      }
      msg.dof_offset = old_global_offset;
   }

   // send the DOFs to element recipients from last Rebalance()
   RebalanceDofMessage::IsendAll(send_rebalance_dofs, MyComm);
}


void ParNCMesh::RecvRebalanceDofs(Array<int> &elements, Array<long> &dofs)
{
   // receive from the same ranks as in last Rebalance()
   RebalanceDofMessage::RecvAll(recv_rebalance_dofs, MyComm);

   // first pass: total number of received element ids and DOFs
   int total_elems = 0, total_dofs = 0;
   for (const auto &entry : recv_rebalance_dofs)
   {
      total_elems += static_cast<int>(entry.second.elem_ids.size());
      total_dofs += static_cast<int>(entry.second.dofs.size());
   }
   elements.SetSize(total_elems);
   dofs.SetSize(total_dofs);

   // second pass: concatenate element ids and DOFs shifted by each sender's
   // global offset
   int e = 0, d = 0;
   for (const auto &entry : recv_rebalance_dofs)
   {
      const RebalanceDofMessage &msg = entry.second;
      for (int id : msg.elem_ids) { elements[e++] = id; }
      for (int dof : msg.dofs) { dofs[d++] = msg.dof_offset + dof; }
   }

   // make sure our own send buffers can be released
   RebalanceDofMessage::WaitAllSent(send_rebalance_dofs);
}


//// ElementSet ////////////////////////////////////////////////////////////////

ParNCMesh::ElementSet::ElementSet(const ElementSet &other)
   : ncmesh(other.ncmesh),
     include_ref_types(other.include_ref_types)
{
   // copy the encoded bytes of 'other' into our own array
   other.data.Copy(data);
}

void ParNCMesh::ElementSet::WriteInt(int value)
{
   // append the four bytes of 'value', least significant byte first
   for (int shift = 0; shift < 32; shift += 8)
   {
      data.Append((value >> shift) & 0xff);
   }
}

int ParNCMesh::ElementSet::GetInt(int pos) const
{
   // reassemble the little-endian 4-byte value stored at data[pos..pos+3]
   // (the inverse of WriteInt)
   unsigned int bits = 0;
   for (int b = 3; b >= 0; b--)
   {
      bits = (bits << 8) | static_cast<unsigned int>(data[pos + b]);
   }
   return static_cast<int>(bits);
}

void ParNCMesh::ElementSet::FlagElements(const Array<int> &elements, char flag)
{
   for (int i = 0; i < elements.Size(); i++)
   {
      int elem = elements[i];
      while (elem >= 0)
      {
         Element &el = ncmesh->elements[elem];
         if (el.flag == flag) { break; }
         el.flag = flag;
         elem = el.parent;
      }
   }
}

void ParNCMesh::ElementSet::EncodeTree(int elem)
{
   Element &el = ncmesh->elements[elem];
   if (!el.ref_type)
   {
      // we reached a leaf, mark this as zero child mask
      data.Append(0);
   }
   else
   {
      // check which subtrees contain marked elements
      int mask = 0;
      for (int i = 0; i < 8; i++)
      {
         if (el.child[i] >= 0 && ncmesh->elements[el.child[i]].flag)
         {
            mask |= 1 << i;
         }
      }

      // write the bit mask and visit the subtrees
      data.Append(mask);
      if (include_ref_types)
      {
         data.Append(el.ref_type);
      }

      for (int i = 0; i < 8; i++)
      {
         if (mask & (1 << i))
         {
            EncodeTree(el.child[i]);
         }
      }
   }
}

void ParNCMesh::ElementSet::Encode(const Array<int> &elements)
{
   // temporarily mark the elements and their ancestors
   FlagElements(elements, 1);

   // Each root whose tree contains a marked element is written as its index
   // (WriteInt) followed by the EncodeTree() output for that tree.
   const int nroots = ncmesh->root_state.Size();
   for (int root = 0; root < nroots; root++)
   {
      if (!ncmesh->elements[root].flag) { continue; }
      WriteInt(root);
      EncodeTree(root);
   }
   WriteInt(-1); // terminator

   // remove the temporary marks
   FlagElements(elements, 0);
}

#ifdef MFEM_DEBUG
std::string ParNCMesh::ElementSet::RefPath() const
{
   // One line per element on the current decoding path: its index and the
   // nodes of its vertices (debug output used in DecodeTree assertions).
   std::ostringstream path;
   for (int k = 0; k < ref_path.Size(); k++)
   {
      const int elem = ref_path[k];
      const Element &el = ncmesh->elements[elem];
      path << "     elem " << elem << " (";
      const int nv = GI[el.Geom()].nv;
      for (int v = 0; v < nv; v++)
      {
         path << (v ? ", " : "") << ncmesh->RetrieveNode(el, v);
      }
      path << ")\n";
   }
   return path.str();
}
#endif

// Recursively decode the (sub)tree rooted at 'elem' from data[pos...],
// advancing 'pos' past the consumed bytes. Elements written with a zero child
// mask are appended to 'elements'. When 'include_ref_types' is set, elements
// that are not yet refined locally are refined with the transmitted ref_type
// so that the encoded path can be followed (this modifies the mesh).
void ParNCMesh::ElementSet::DecodeTree(int elem, int &pos,
                                       Array<int> &elements) const
{
#ifdef MFEM_DEBUG
   // track the current path for the RefPath() diagnostics
   ref_path.Append(elem);
#endif
   // bit i set: the subtree of child i follows in the stream
   int mask = data[pos++];
   if (!mask)
   {
      // zero mask: 'elem' is one of the encoded elements
      elements.Append(elem);
   }
   else
   {
      Element &el = ncmesh->elements[elem];
      if (include_ref_types)
      {
         int ref_type = data[pos++];
         if (!el.ref_type)
         {
            // not refined here yet: refine it the way the sender has it
            ncmesh->RefineElement(elem, ref_type);
         }
         else { MFEM_ASSERT(ref_type == el.ref_type, "") }
      }
      else
      {
         // without ref types, the local tree must already contain the path
         MFEM_ASSERT(el.ref_type != 0, "Path not found:\n"
                     << RefPath() << "     mask = " << mask);
      }

      // visit the marked children in the same order EncodeTree() wrote them
      for (int i = 0; i < 8; i++)
      {
         if (mask & (1 << i))
         {
            DecodeTree(el.child[i], pos, elements);
         }
      }
   }
#ifdef MFEM_DEBUG
   ref_path.DeleteLast();
#endif
}

void ParNCMesh::ElementSet::Decode(Array<int> &elements) const
{
   // Read (root index, tree) records until a negative root index (Encode()
   // writes -1 as terminator); decoded elements are appended to 'elements'.
   int pos = 0;
   for (;;)
   {
      const int root = GetInt(pos);
      if (root < 0) { break; }
      pos += 4;
      DecodeTree(root, pos, elements);
   }
}

void ParNCMesh::ElementSet::Dump(std::ostream &os) const
{
   // byte count followed by the raw encoded bytes
   const int nbytes = data.Size();
   write<int>(os, nbytes);
   os.write(reinterpret_cast<const char*>(data.GetData()), nbytes);
}

void ParNCMesh::ElementSet::Load(std::istream &is)
{
   // counterpart of Dump(): byte count, then that many raw bytes
   const int nbytes = read<int>(is);
   data.SetSize(nbytes);
   is.read(reinterpret_cast<char*>(data.GetData()), nbytes);
}


//// EncodeMeshIds/DecodeMeshIds ///////////////////////////////////////////////

// For vertex ids (ids[0]) and edge ids (ids[1]) that lie on master edges or
// master faces shared with 'rank', rewrite the element/local part of the
// MeshId so it refers to the element of that master edge/face, i.e. an
// element/local that is "safe for 'rank'". ids[2] is not modified, and the
// ids are not reordered; only 'element' and 'local' change.
void ParNCMesh::AdjustMeshIds(Array<MeshId> ids[], int rank)
{
   // make sure the shared entity lists are up to date
   GetSharedVertices();
   GetSharedEdges();
   GetSharedFaces();

   // nothing to adjust without shared master edges/faces
   if (!shared_edges.masters.Size() &&
       !shared_faces.masters.Size()) { return; }

   // contains_rank[g]: does communication group g include 'rank'?
   Array<bool> contains_rank(static_cast<int>(groups.size()));
   for (unsigned i = 0; i < groups.size(); i++)
   {
      contains_rank[i] = GroupContains(i, rank);
   }

   // (vertex index, position in ids[0]) pairs, sorted for FindSorted();
   // ChangeRemainingMeshIds() then updates the following entries with the
   // same vertex (assumes FindSorted() returns the first match -- verify)
   Array<Pair<int, int> > find_v(ids[0].Size());
   for (int i = 0; i < ids[0].Size(); i++)
   {
      find_v[i].one = ids[0][i].index;
      find_v[i].two = i;
   }
   find_v.Sort();

   // find vertices of master edges shared with 'rank', and modify their
   // MeshIds so their element/local matches the element of the master edge
   for (int i = 0; i < shared_edges.masters.Size(); i++)
   {
      const MeshId &edge_id = shared_edges.masters[i];
      if (contains_rank[entity_pmat_group[1][edge_id.index]])
      {
         int v[2], pos, k;
         GetEdgeVertices(edge_id, v);
         for (int j = 0; j < 2; j++)
         {
            if ((pos = find_v.FindSorted(Pair<int, int>(v[j], 0))) != -1)
            {
               // switch to an element/local that is safe for 'rank'
               k = find_v[pos].two;
               ChangeVertexMeshIdElement(ids[0][k], edge_id.element);
               ChangeRemainingMeshIds(ids[0], pos, find_v);
            }
         }
      }
   }

   // edges only need adjusting when there are shared master faces
   if (!shared_faces.masters.Size()) { return; }

   // (edge index, position in ids[1]) pairs, sorted for FindSorted()
   Array<Pair<int, int> > find_e(ids[1].Size());
   for (int i = 0; i < ids[1].Size(); i++)
   {
      find_e[i].one = ids[1][i].index;
      find_e[i].two = i;
   }
   find_e.Sort();

   // find vertices/edges of master faces shared with 'rank', and modify their
   // MeshIds so their element/local matches the element of the master face
   for (const MeshId &face_id : shared_faces.masters)
   {
      if (contains_rank[entity_pmat_group[2][face_id.index]])
      {
         // a face has as many edges as vertices, so one loop covers both
         int v[4], e[4], eo[4], pos, k;
         int nfv = GetFaceVerticesEdges(face_id, v, e, eo);
         for (int j = 0; j < nfv; j++)
         {
            if ((pos = find_v.FindSorted(Pair<int, int>(v[j], 0))) != -1)
            {
               k = find_v[pos].two;
               ChangeVertexMeshIdElement(ids[0][k], face_id.element);
               ChangeRemainingMeshIds(ids[0], pos, find_v);
            }
            if ((pos = find_e.FindSorted(Pair<int, int>(e[j], 0))) != -1)
            {
               k = find_e[pos].two;
               ChangeEdgeMeshIdElement(ids[1][k], face_id.element);
               ChangeRemainingMeshIds(ids[1], pos, find_e);
            }
         }
      }
   }
}

void ParNCMesh::ChangeVertexMeshIdElement(NCMesh::MeshId &id, int elem)
{
   // Re-express vertex 'id' as a local vertex of the leaf element 'elem';
   // aborts if 'elem' does not contain the vertex.
   Element &el = elements[elem];
   MFEM_ASSERT(el.ref_type == 0, "");

   const int nv = GI[el.Geom()].nv;
   int local = 0;
   while (local < nv && nodes[el.node[local]].vert_index != id.index)
   {
      local++;
   }

   if (local < nv)
   {
      id.local = local;
      id.element = elem;
      return;
   }
   MFEM_ABORT("Vertex not found.");
}

void ParNCMesh::ChangeEdgeMeshIdElement(NCMesh::MeshId &id, int elem)
{
   // identify the edge node through the element 'id' currently refers to
   Element &old_el = elements[id.element];
   const int *old_ev = GI[old_el.Geom()].edges[(int) id.local];
   Node* edge = nodes.Find(old_el.node[old_ev[0]], old_el.node[old_ev[1]]);
   MFEM_ASSERT(edge != NULL, "Edge not found.");

   Element &el = elements[elem];
   MFEM_ASSERT(el.ref_type == 0, "");

   // find the same edge, in either orientation, among the edges of 'elem'
   GeomInfo& gi = GI[el.Geom()];
   for (int i = 0; i < gi.ne; i++)
   {
      const int a = el.node[gi.edges[i][0]];
      const int b = el.node[gi.edges[i][1]];
      const bool same_edge = (a == edge->p1 && b == edge->p2) ||
                             (b == edge->p1 && a == edge->p2);
      if (same_edge)
      {
         id.local = i;
         id.element = elem;
         return;
      }
   }
   MFEM_ABORT("Edge not found.");
}

void ParNCMesh::ChangeRemainingMeshIds(Array<MeshId> &ids, int pos,
                                       const Array<Pair<int, int> > &find)
{
   // 'find' holds (index, position in ids) pairs sorted by index. Copy the
   // element/local of the id at find[pos] to the ids listed right after it
   // that have the same index.
   const MeshId &first = ids[find[pos].two];
   for (int p = pos + 1; p < find.Size(); p++)
   {
      MeshId &other = ids[find[p].two];
      if (other.index != first.index) { break; }
      other.element = first.element;
      other.local = first.local;
   }
}

void ParNCMesh::EncodeMeshIds(std::ostream &os, Array<MeshId> ids[])
{
   std::map<int, int> stream_id;

   // get a list of elements involved, dump them to 'os' and create the mapping
   // element_id: (Element index -> stream ID)
   {
      Array<int> elements;
      for (int type = 0; type < 3; type++)
      {
         for (int i = 0; i < ids[type].Size(); i++)
         {
            elements.Append(ids[type][i].element);
         }
      }

      ElementSet eset(this);
      eset.Encode(elements);
      eset.Dump(os);

      Array<int> decoded;
      decoded.Reserve(elements.Size());
      eset.Decode(decoded);

      for (int i = 0; i < decoded.Size(); i++)
      {
         stream_id[decoded[i]] = i;
      }
   }

   // write the IDs as element/local pairs
   for (int type = 0; type < 3; type++)
   {
      write<int>(os, ids[type].Size());
      for (int i = 0; i < ids[type].Size(); i++)
      {
         const MeshId& id = ids[type][i];
         write<int>(os, stream_id[id.element]); // TODO: variable 1-4 bytes
         write<char>(os, id.local);
      }
   }
}

// Inverse of EncodeMeshIds(): read the element set and the per-entity
// (stream element number, local number) pairs from 'is', and rebuild ids[0]
// (vertices), ids[1] (edges) and ids[2] (faces) using this mesh's element
// numbers and vertex/edge/face indices.
void ParNCMesh::DecodeMeshIds(std::istream &is, Array<MeshId> ids[])
{
   // read the list of elements
   ElementSet eset(this);
   eset.Load(is);

   // elems[stream element number] = local element index
   Array<int> elems;
   eset.Decode(elems);

   // read vertex/edge/face IDs
   for (int type = 0; type < 3; type++)
   {
      int ne = read<int>(is);
      ids[type].SetSize(ne);

      for (int i = 0; i < ne; i++)
      {
         int el_num = read<int>(is);
         int elem = elems[el_num];
         Element &el = elements[elem];

         // only leaf elements are accepted
         MFEM_VERIFY(!el.ref_type, "not a leaf element: " << el_num);

         MeshId &id = ids[type][i];
         id.element = elem;
         id.local = read<char>(is);

         // find vertex/edge/face index
         GeomInfo &gi = GI[el.Geom()];
         switch (type)
         {
            case 0:
            {
               // 'local' is the element's local vertex number
               id.index = nodes[el.node[(int) id.local]].vert_index;
               break;
            }
            case 1:
            {
               // 'local' is the element's local edge number; the edge is
               // looked up as the node between its two end nodes
               const int* ev = gi.edges[(int) id.local];
               Node* node = nodes.Find(el.node[ev[0]], el.node[ev[1]]);
               MFEM_ASSERT(node && node->HasEdge(), "edge not found.");
               id.index = node->edge_index;
               break;
            }
            default:
            {
               // 'local' is the element's local face number
               const int* fv = gi.faces[(int) id.local];
               Face* face = faces.Find(el.node[fv[0]], el.node[fv[1]],
                                       el.node[fv[2]], el.node[fv[3]]);
               MFEM_ASSERT(face, "face not found.");
               id.index = face->index;
            }
         }
      }
   }
}

void ParNCMesh::EncodeGroups(std::ostream &os, const Array<GroupId> &ids)
{
   // get a list of unique GroupIds
   std::map<GroupId, GroupId> stream_id;
   for (int i = 0; i < ids.Size(); i++)
   {
      if (i && ids[i] == ids[i-1]) { continue; }
      unsigned size = stream_id.size();
      GroupId &sid = stream_id[ids[i]];
      if (size != stream_id.size()) { sid = size; }
   }

   // write the unique groups
   write<short>(os, stream_id.size());
   for (std::map<GroupId, GroupId>::iterator
        it = stream_id.begin(); it != stream_id.end(); ++it)
   {
      write<GroupId>(os, it->second);
      if (it->first >= 0)
      {
         const CommGroup &group = groups[it->first];
         write<short>(os, group.size());
         for (unsigned i = 0; i < group.size(); i++)
         {
            write<int>(os, group[i]);
         }
      }
      else
      {
         // special "invalid" group, marks forwarded rows
         write<short>(os, -1);
      }
   }

   // write the list of all GroupIds
   write<int>(os, ids.Size());
   for (int i = 0; i < ids.Size(); i++)
   {
      write<GroupId>(os, stream_id[ids[i]]);
   }
}

void ParNCMesh::DecodeGroups(std::istream &is, Array<GroupId> &ids)
{
   int ngroups = read<short>(is);
   Array<GroupId> sgroups(ngroups);

   // read stream groups, convert to our groups
   CommGroup ranks;
   ranks.reserve(128);
   for (int i = 0; i < ngroups; i++)
   {
      int id = read<GroupId>(is);
      int size = read<short>(is);
      if (size >= 0)
      {
         ranks.resize(size);
         for (int ii = 0; ii < size; ii++)
         {
            ranks[ii] = read<int>(is);
         }
         sgroups[id] = GetGroupId(ranks);
      }
      else
      {
         sgroups[id] = -1; // forwarded
      }
   }

   // read the list of IDs
   ids.SetSize(read<int>(is));
   for (int i = 0; i < ids.Size(); i++)
   {
      ids[i] = sgroups[read<GroupId>(is)];
   }
}


//// Messages //////////////////////////////////////////////////////////////////

template<class ValueType, bool RefTypes, int Tag>
void ParNCMesh::ElementValueMessage<ValueType, RefTypes, Tag>::Encode(int)
{
   std::ostringstream out;

   // view the element list as an Array (no copy) for ElementSet
   Array<int> elem_ref;
   elem_ref.MakeRef(elements.data(), static_cast<int>(elements.size()));

   ElementSet eset(pncmesh, RefTypes);
   eset.Encode(elem_ref);
   eset.Dump(out);

   // The decoded order defines the element numbering the receiver will see;
   // map each element to its position in that order.
   Array<int> decoded;
   decoded.Reserve(elem_ref.Size());
   eset.Decode(decoded);

   std::map<int, int> local_number;
   for (int i = 0; i < decoded.Size(); i++) { local_number[decoded[i]] = i; }

   // (element number, value) records
   const int nvalues = static_cast<int>(values.size());
   write<int>(out, nvalues);
   MFEM_ASSERT(elements.size() == values.size(), "");

   for (int i = 0; i < nvalues; i++)
   {
      write<int>(out, local_number[elements[i]]);
      write<ValueType>(out, values[i]);
   }

   out.str().swap(data);
}

template<class ValueType, bool RefTypes, int Tag>
void ParNCMesh::ElementValueMessage<ValueType, RefTypes, Tag>::Decode(int)
{
   std::istringstream in(data);

   // element set first: decoding yields the elements in the order the
   // sender numbered them
   ElementSet eset(pncmesh, RefTypes);
   eset.Load(in);

   Array<int> decoded;
   eset.Decode(decoded);

   const int *first = decoded.GetData();
   elements.assign(first, first + decoded.Size());
   values.resize(elements.size());

   // then (element number, value) records
   const int nrecords = read<int>(in);
   for (int r = 0; r < nrecords; r++)
   {
      const int index = read<int>(in);
      MFEM_ASSERT(index >= 0 && (size_t) index < values.size(), "");
      values[index] = read<ValueType>(in);
   }

   data.clear(); // raw bytes are no longer needed
}

void ParNCMesh::RebalanceDofMessage::SetElements(const Array<int> &elems,
                                                 NCMesh *ncmesh)
{
   eset.SetNCMesh(ncmesh);
   eset.Encode(elems);

   // Decode right away so that 'elem_ids' follows the same element order
   // the receiver obtains when it decodes this set.
   Array<int> decoded;
   decoded.Reserve(elems.Size());
   eset.Decode(decoded);

   auto *mesh = eset.GetNCMesh();
   const int n = decoded.Size();
   elem_ids.resize(n);
   for (int k = 0; k < n; k++)
   {
      elem_ids[k] = mesh->elements[decoded[k]].index;
   }
}

static void write_dofs(std::ostream &os, const std::vector<int> &dofs)
{
   // length prefix followed by the raw int array
   // TODO: we should compress the ints, mostly they are contiguous ranges
   const std::size_t nbytes = dofs.size() * sizeof(int);
   write<int>(os, static_cast<int>(dofs.size()));
   os.write(reinterpret_cast<const char*>(dofs.data()), nbytes);
}

static void read_dofs(std::istream &is, std::vector<int> &dofs)
{
   // counterpart of write_dofs(): length prefix, then the raw ints
   const int count = read<int>(is);
   dofs.resize(count);
   is.read(reinterpret_cast<char*>(dofs.data()), dofs.size() * sizeof(int));
}

void ParNCMesh::RebalanceDofMessage::Encode(int)
{
   // wire format: element set, global DOF offset, DOF list
   std::ostringstream out;
   eset.Dump(out);
   write<long>(out, dof_offset);
   write_dofs(out, dofs);
   out.str().swap(data);
}

void ParNCMesh::RebalanceDofMessage::Decode(int)
{
   std::istringstream stream(data);

   eset.Load(stream);
   dof_offset = read<long>(stream);
   read_dofs(stream, dofs);

   data.clear();

   Array<int> elems;
   eset.Decode(elems);

   elem_ids.resize(elems.Size());
   for (int i = 0; i < elems.Size(); i++)
   {
      elem_ids[i] = eset.GetNCMesh()->elements[elems[i]].index;
   }
}


//// Utility ///////////////////////////////////////////////////////////////////

void ParNCMesh::GetDebugMesh(Mesh &debug_mesh) const
{
   // serial NCMesh copy containing all our elements (ghosts and all)
   NCMesh* copy = new NCMesh(*this);

   // color the leaves by owner: attribute = rank + 1
   for (int leaf : copy->leaf_elements)
   {
      Element &el = copy->elements[leaf];
      el.attribute = el.rank + 1;
   }

   debug_mesh.InitFromNCMesh(*copy);
   debug_mesh.SetAttributes();
   debug_mesh.ncmesh = copy; // handed to 'debug_mesh', not deleted here
}

void ParNCMesh::Trim()
{
   NCMesh::Trim();

   // shared entity lists
   shared_vertices.Clear();
   shared_edges.Clear();
   shared_faces.Clear();

   // per-entity-type (vertex, edge, face) owner/group/rank arrays
   for (int type = 0; type < 3; ++type)
   {
      entity_owner[type].DeleteAll();
      entity_pmat_group[type].DeleteAll();
      entity_index_rank[type].DeleteAll();
   }

   // bookkeeping left over from the last Rebalance()
   send_rebalance_dofs.clear();
   recv_rebalance_dofs.clear();
   old_index_or_rank.DeleteAll();

   ClearAuxPM();
}

std::size_t ParNCMesh::RebalanceDofMessage::MemoryUsage() const
{
   // allocated capacity of both int vectors (element ids and DOFs)
   const std::size_t n_ints = elem_ids.capacity() + dofs.capacity();
   return n_ints * sizeof(int);
}

template<typename K, typename V>
static std::size_t map_memory_usage(const std::map<K, V> &map)
{
   // Rough per-node overhead of a std::map: the stored pair plus three tree
   // pointers and a color flag (an approximation of a red-black tree node).
   const std::size_t node_overhead =
      sizeof(std::pair<K, V>) + 3*sizeof(void*) + sizeof(bool);

   std::size_t result = 0;
   for (const auto &entry : map)
   {
      result += entry.second.MemoryUsage() + node_overhead;
   }
   return result;
}

std::size_t ParNCMesh::GroupsMemoryUsage() const
{
   std::size_t groups_size = groups.capacity() * sizeof(CommGroup);
   for (unsigned i = 0; i < groups.size(); i++)
   {
      groups_size += groups[i].capacity() * sizeof(int);
   }
   const int approx_node_size =
      sizeof(std::pair<CommGroup, GroupId>) + 3*sizeof(void*) + sizeof(bool);
   return groups_size + group_id.size() * approx_node_size;
}

template<typename Type, int Size>
static std::size_t arrays_memory_usage(const Array<Type> (&arrays)[Size])
{
   std::size_t total = 0;
   for (int i = 0; i < Size; i++)
   {
      total += arrays[i].MemoryUsage();
   }
   return total;
}

std::size_t ParNCMesh::MemoryUsage(bool with_base) const
{
   return (with_base ? NCMesh::MemoryUsage() : 0) +
          GroupsMemoryUsage() +
          arrays_memory_usage(entity_owner) +
          arrays_memory_usage(entity_pmat_group) +
          arrays_memory_usage(entity_conf_group) +
          arrays_memory_usage(entity_elem_local) +
          shared_vertices.MemoryUsage() +
          shared_edges.MemoryUsage() +
          shared_faces.MemoryUsage() +
          face_orient.MemoryUsage() +
          element_type.MemoryUsage() +
          ghost_layer.MemoryUsage() +
          boundary_layer.MemoryUsage() +
          tmp_owner.MemoryUsage() +
          tmp_shared_flag.MemoryUsage() +
          arrays_memory_usage(entity_index_rank) +
          tmp_neighbors.MemoryUsage() +
          map_memory_usage(send_rebalance_dofs) +
          map_memory_usage(recv_rebalance_dofs) +
          old_index_or_rank.MemoryUsage() +
          aux_pm_store.MemoryUsage() +
          sizeof(ParNCMesh) - sizeof(NCMesh);
}

int ParNCMesh::PrintMemoryDetail(bool with_base) const
{
   if (with_base) { NCMesh::PrintMemoryDetail(); }

   // One "<bytes> <member>" line per component, in the same order as the
   // terms of MemoryUsage(); the last line flushes the stream.
   mfem::out << GroupsMemoryUsage() << " groups\n";
   mfem::out << arrays_memory_usage(entity_owner) << " entity_owner\n";
   mfem::out << arrays_memory_usage(entity_pmat_group)
             << " entity_pmat_group\n";
   mfem::out << arrays_memory_usage(entity_conf_group)
             << " entity_conf_group\n";
   mfem::out << arrays_memory_usage(entity_elem_local)
             << " entity_elem_local\n";
   mfem::out << shared_vertices.MemoryUsage() << " shared_vertices\n";
   mfem::out << shared_edges.MemoryUsage() << " shared_edges\n";
   mfem::out << shared_faces.MemoryUsage() << " shared_faces\n";
   mfem::out << face_orient.MemoryUsage() << " face_orient\n";
   mfem::out << element_type.MemoryUsage() << " element_type\n";
   mfem::out << ghost_layer.MemoryUsage() << " ghost_layer\n";
   mfem::out << boundary_layer.MemoryUsage() << " boundary_layer\n";
   mfem::out << tmp_owner.MemoryUsage() << " tmp_owner\n";
   mfem::out << tmp_shared_flag.MemoryUsage() << " tmp_shared_flag\n";
   mfem::out << arrays_memory_usage(entity_index_rank)
             << " entity_index_rank\n";
   mfem::out << tmp_neighbors.MemoryUsage() << " tmp_neighbors\n";
   mfem::out << map_memory_usage(send_rebalance_dofs)
             << " send_rebalance_dofs\n";
   mfem::out << map_memory_usage(recv_rebalance_dofs)
             << " recv_rebalance_dofs\n";
   mfem::out << old_index_or_rank.MemoryUsage() << " old_index_or_rank\n";
   mfem::out << aux_pm_store.MemoryUsage() << " aux_pm_store\n";
   mfem::out << sizeof(ParNCMesh) - sizeof(NCMesh) << " ParNCMesh"
             << std::endl;

   // number of leaf elements
   return leaf_elements.Size();
}

void ParNCMesh::GetGhostElements(Array<int> & gelem)
{
   // The ghost leaves are the NGhostElements entries of 'leaf_elements' that
   // follow the first NElements. Return them as indices into
   // NCMesh::elements, the array of all elements (cf.
   // NCMesh::OnMeshUpdated).
   gelem.SetSize(NGhostElements);
   for (int k = 0; k < gelem.Size(); k++)
   {
      gelem[k] = leaf_elements[NElements + k];
   }
}

// Note that this function is modeled after ParNCMesh::Refine().
//
// Exchange per-element p-refinement data (VarOrderElemInfo: element, order)
// with the neighbor ranks. 'sendData[i].element' must be a local leaf index
// (< NElements). 'recvData' is overwritten with all (element, order) pairs
// received from the neighbors. Does nothing on a single rank.
void ParNCMesh::CommunicateGhostData(
   const Array<VarOrderElemInfo> & sendData, Array<VarOrderElemInfo> & recvData)
{
   recvData.SetSize(0);

   if (NRanks == 1) { return; }

   NeighborPRefinementMessage::Map send_ref;

   // Create one message per neighbor (NOTE: some may stay empty). Every
   // neighbor must get exactly one message, because the receive loop below
   // expects one message from each neighbor.
   Array<int> neighbors;
   NeighborProcessors(neighbors);
   for (int i = 0; i < neighbors.Size(); i++)
   {
      send_ref[neighbors[i]].SetNCMesh(this);
   }

   // populate messages: the order of each listed element goes to every rank
   // that ElementNeighborProcessors() reports for it, so those ranks can
   // update their ghost copies of the element
   Array<int> ranks;
   ranks.Reserve(64);
   for (int i = 0; i < sendData.Size(); i++)
   {
      MFEM_ASSERT(sendData[i].element < (unsigned int) NElements, "");
      const int elem = leaf_elements[sendData[i].element];
      ElementNeighborProcessors(elem, ranks);
      for (int j = 0; j < ranks.Size(); j++)
      {
         send_ref[ranks[j]].AddRefinement(elem, sendData[i].order);
      }
   }

   // post the sends (completed by WaitAllSent() at the end); unlike Refine(),
   // there is no local work to overlap with here
   NeighborPRefinementMessage::IsendAll(send_ref, MyComm);

   // receive one message from each neighbor, in arrival order (Probe reports
   // the sender rank and the message size)
   for (int j = 0; j < neighbors.Size(); j++)
   {
      int rank, size;
      NeighborPRefinementMessage::Probe(rank, size, MyComm);

      NeighborPRefinementMessage msg;
      msg.SetNCMesh(this);
      msg.Recv(rank, size, MyComm);

      // Append the received (element, order) pairs. NOTE(review): 'element'
      // is whatever the message decoded (presumably an NCMesh::elements
      // index, not a leaf index as in 'sendData') -- confirm with callers.
      const int os = recvData.Size();
      recvData.SetSize(os + msg.Size());
      for (int i = 0; i < msg.Size(); i++)
      {
         recvData[os + i].element = msg.elements[i];
         recvData[os + i].order = msg.values[i];
      }
   }

   // make sure we can delete the send buffers
   NeighborPRefinementMessage::WaitAllSent(send_ref);
}

} // namespace mfem

#endif // MFEM_USE_MPI
